Tutorial 2a shows various ways to fit Bayesian binary regression models to discrete time-to-event data. The data are taken from the first experiment of Panis and Schmidt (2016), but using only the no-mask trials (with factor prime = blank, congruent, or incongruent). We perform some prior predictive checks to infer the prior distributions we will employ. We compare models using WAIC and LOO, and interpret parameter estimates for one model. We end by plotting the logit and cloglog link functions, and illustrate how various prior distributions on the logit and cloglog scales look on the original probability (i.e., hazard) scale.
## Packages used throughout this tutorial: Stan interfaces (cmdstanr, brms),
## tidyverse data wrangling, and plotting/posterior helpers.
pkgs <- c("cmdstanr", "standist", "tidyverse", "RColorBrewer", "patchwork",
          "brms", "tidybayes", "bayesplot", "future", "parallel", "modelr")
lapply(pkgs, library, character.only = TRUE)
## This is cmdstanr version 0.8.1.9000
## - CmdStanR documentation and vignettes: mc-stan.org/cmdstanr
## - CmdStan path: /Users/spanis/.cmdstan/cmdstan-2.35.0
## - CmdStan version: 2.35.0
## ── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
## ✔ dplyr 1.1.4 ✔ readr 2.1.5
## ✔ forcats 1.0.0 ✔ stringr 1.5.1
## ✔ ggplot2 3.5.1 ✔ tibble 3.2.1
## ✔ lubridate 1.9.3 ✔ tidyr 1.3.1
## ✔ purrr 1.0.2
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag() masks stats::lag()
## ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
## Loading required package: Rcpp
##
## Loading 'brms' package (version 2.21.0). Useful instructions
## can be found by typing help('brms'). A more detailed introduction
## to the package is available through vignette('brms_overview').
##
##
## Attaching package: 'brms'
##
##
## The following object is masked from 'package:stats':
##
## ar
##
##
##
## Attaching package: 'tidybayes'
##
##
## The following objects are masked from 'package:brms':
##
## dstudent_t, pstudent_t, qstudent_t, rstudent_t
##
##
## This is bayesplot version 1.11.1
##
## - Online documentation and vignettes at mc-stan.org/bayesplot
##
## - bayesplot theme set to bayesplot::theme_default()
##
## * Does _not_ affect other ggplot2 plots
##
## * See ?bayesplot_theme_set for details on theme setting
##
##
## Attaching package: 'bayesplot'
##
##
## The following object is masked from 'package:brms':
##
## rhat
Set options.
## Global options: run brms via the cmdstanr backend, parallelize across all
## available cores, allow forked ("multicore") futures, and silence the
## future package's RNG-misuse warnings.
options(brms.backend = "cmdstanr",
mc.cores = parallel::detectCores(),
future.fork.enable = TRUE,
future.rng.onMisuse = "ignore") ## automatically set in RStudio
## Diagnostics: check multicore support, core count, and package versions.
supportsMulticore()
detectCores()
packageVersion("cmdstanr")
devtools::session_info("rstan")
Theme settings for ggplot.
## ggplot theme for all figures: black-and-white base theme with large bold
## text and the legend placed below the plot.
theme_set(
theme_bw() +
theme(text = element_text(size = 22, face = "bold"),
title = element_text(size = 22, face = "bold"),
legend.position = "bottom")
)
## Set the amount of dodge in figures
pd <- position_dodge(0.7)
pd2 <- position_dodge(1)
## Read the person-trial-bin data set prepared in Tutorial 1
## (one row per participant x trial x discrete time bin).
ptb_data <- read_csv("Tutorial_1_descriptive_stats/data/inputfile_hazard_modeling.csv")
## Rows: 26602 Columns: 7
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## chr (1): condition
## dbl (6): pid, bl, tr, trial, period, event
##
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
head(ptb_data)
## # A tibble: 6 × 7
## pid bl tr trial condition period event
## <dbl> <dbl> <dbl> <dbl> <chr> <dbl> <dbl>
## 1 1 2 4 4 congruent 1 0
## 2 1 2 4 4 congruent 2 0
## 3 1 2 4 4 congruent 3 0
## 4 1 2 4 4 congruent 4 0
## 5 1 2 4 4 congruent 5 0
## 6 1 2 4 4 congruent 6 0
Wrangle the data set.
## Wrangle: restrict to the analysis window and build the TIME, trial-number,
## and prime-type predictors used by the models below.
ptb_data <- ptb_data %>%
# select analysis time range: (200,600] with 10 bins (time bin ranks 6 to 15)
filter(period > 5) %>%
# create categorical predictor for TIME named "timebin" with index coding
mutate(timebin = factor(period,levels=c(6:15)),
# create continuous predictor for TIME named "period_9", centered on bin 9,
period_9 = period - 9,
# create binary variables to indicate each bin
d6 = if_else(period == 6, 1, 0),
d7 = if_else(period == 7, 1, 0),
d8 = if_else(period == 8, 1, 0),
d9 = if_else(period == 9, 1, 0),
d10 = if_else(period == 10, 1, 0),
d11 = if_else(period == 11, 1, 0),
d12 = if_else(period == 12, 1, 0),
d13 = if_else(period == 13, 1, 0),
d14 = if_else(period == 14, 1, 0),
d15 = if_else(period == 15, 1, 0),
# create continuous predictor for trial number named "trial_c", centered on bin 1000 and rescale
trial_c = (trial - 1000)/1000,
# create categorical predictor for trial number named "stage" (early,middle,late) with index coding
# (stage 1 = trials <= 500, stage 2 = trials 501-1000, stage 3 = trials > 1000)
stage = ifelse(trial <= 500, 1, ifelse(trial > 1000, 3, 2)),
stage = factor(stage, levels=c(1,2,3)),
# create factor "condition", with "blank" as the reference level
# NOTE(review): factor(..., labels=) assigns labels to levels in sorted order;
# this assumes the raw condition strings sort as blank < congruent < incongruent — confirm
condition = factor(condition, labels = c("blank", "congruent","incongruent")),
# create categorical predictor "prime" with index-coding
prime = ifelse(condition=="blank",1, ifelse(condition=="congruent",2,3)),
prime = factor(prime,levels=c(1,2,3))) %>%
select(pid,event,trial,trial_c,stage,condition,prime,period,period_9,timebin,d6:d15)
# inspect: rows 1-7 show one complete trial, rows 8-17 a longer one
head(ptb_data,n=17)
## # A tibble: 17 × 20
## pid event trial trial_c stage condition prime period period_9 timebin d6
## <dbl> <dbl> <dbl> <dbl> <fct> <fct> <fct> <dbl> <dbl> <fct> <dbl>
## 1 1 0 4 -0.996 1 congruent 2 6 -3 6 1
## 2 1 0 4 -0.996 1 congruent 2 7 -2 7 0
## 3 1 0 4 -0.996 1 congruent 2 8 -1 8 0
## 4 1 0 4 -0.996 1 congruent 2 9 0 9 0
## 5 1 0 4 -0.996 1 congruent 2 10 1 10 0
## 6 1 0 4 -0.996 1 congruent 2 11 2 11 0
## 7 1 1 4 -0.996 1 congruent 2 12 3 12 0
## 8 1 0 5 -0.995 1 incongru… 3 6 -3 6 1
## 9 1 0 5 -0.995 1 incongru… 3 7 -2 7 0
## 10 1 0 5 -0.995 1 incongru… 3 8 -1 8 0
## 11 1 0 5 -0.995 1 incongru… 3 9 0 9 0
## 12 1 0 5 -0.995 1 incongru… 3 10 1 10 0
## 13 1 0 5 -0.995 1 incongru… 3 11 2 11 0
## 14 1 0 5 -0.995 1 incongru… 3 12 3 12 0
## 15 1 0 5 -0.995 1 incongru… 3 13 4 13 0
## 16 1 0 5 -0.995 1 incongru… 3 14 5 14 0
## 17 1 1 5 -0.995 1 incongru… 3 15 6 15 0
## # ℹ 9 more variables: d7 <dbl>, d8 <dbl>, d9 <dbl>, d10 <dbl>, d11 <dbl>,
## # d12 <dbl>, d13 <dbl>, d14 <dbl>, d15 <dbl>
The first model we fit is a “random intercepts” model, where we fit a single grand intercept for each timebin and add random intercepts that vary between participants. This constitutes a general specification of TIME, and can be used if we do not want to make assumptions about how cloglog-hazard changes over time (within a trial).
There are two ways to implement such a model. First, we can use the index-coding approach, which provides an intercept for each level of TIME (variable timebin in ptb_data).
Prepare the data file, and specify priors. The skew_normal prior is set for each grand intercept on the cloglog scale, and should reflect our prior beliefs.
## Data for model M0i: event indicator plus the index-coded TIME factor.
data_M0i <- ptb_data %>% select(pid, event, timebin)
## Priors: the skew-normal prior applies to every grand intercept
## (class "b", one per timebin, on the cloglog scale); normal(0,1) for the
## random-effect SDs; LKJ(2) for the random-effect correlation matrix.
priors_M0i <- c(
set_prior("skew_normal(-1,1,-2)", class = "b"),
set_prior("normal(0, 1)", class = "sd"),
set_prior("lkj(2)", class = "cor")
)
Perform a prior predictive check to see if the binary observations generated from the specified prior distributions reflect our prior beliefs. First, sample the prior distributions using sample_prior = "only".
## Use forked parallel workers for the chains.
plan(multicore)
## Fit M0i with sample_prior = "only": the likelihood is ignored, so the
## resulting draws come from the priors (used for prior predictive checks).
## brm() caches the fit at `file` and reuses it on subsequent runs.
model_M0i_prior <-
brm(data = data_M0i,
family = bernoulli(link="cloglog"),
formula = event ~ 0 + timebin + (0 + timebin | pid),
prior = priors_M0i,
chains = 4, cores = 4,
iter = 3000, warmup = 1000,
control = list(adapt_delta = 0.999,
step_size = 0.04,
max_treedepth = 12),
seed = 12, init = "0",
sample_prior = "only",
file = "Tutorial_2_Bayesian/models/model_M0i_prior")
## Reload the cached fit from disk.
model_M0i_prior <- readRDS("Tutorial_2_Bayesian/models/model_M0i_prior.rds")
Next, use them to predict prior data. To better understand what the function add_predicted_draws is doing (which we will typically use), we first simulate prior data for one timebin manually.
# prior predictive check for 1 time bin
# see http://bruno.nicenboim.me/bayescogsci/
# cloglog and inverse-cloglog link functions
# Inverse cloglog link: maps a cloglog-scale value x to a probability, 1 - exp(-exp(x)).
inverse_cloglog <- function(x) 1 - exp(-exp(x))
# cloglog link: maps a probability x in (0, 1) to the cloglog scale, log(-log(1 - x)).
cloglog <- function(x) log(-log(1 - x))
# Specify
N_samples <- 1000 # the number of samples from the skew_normal
N_obs <- 100 # the number of simulated observations per sample
set.seed(1)
# Simulate the prior predictive distribution of binary event indicators.
#
# samples: vector of cloglog-hazard values (draws from the prior).
# Nobs:    number of Bernoulli observations to generate per sample.
#
# Returns a tibble with columns trial (1..Nobs), y_pred (0/1 event
# indicator), and iter (index of the prior sample that generated the row).
cloglog_predictive_distribution <- function(samples, Nobs) {
  # Build one tibble per prior sample and bind them once at the end
  # (avoids growing a data frame with bind_rows() inside the loop).
  pred_list <- vector("list", length(samples))
  for (i in seq_along(samples)) {
    cloglog_haz <- samples[i]
    pred_list[[i]] <- tibble(
      trial = seq_len(Nobs),
      # bug fix: use the Nobs argument here; the original referenced the
      # global N_obs, silently ignoring the function's own parameter
      y_pred = rbinom(Nobs, 1, inverse_cloglog(cloglog_haz)),
      iter = i
    )
  }
  bind_rows(pred_list)
}
# sample prior cloglog-hazard values for 1 bin
cloglog_samples <- rskew_normal(N_samples,-1,1,-2)
# apply the function: N_obs binary observations per prior sample
prior_pred <- cloglog_predictive_distribution(cloglog_samples,N_obs)
# calculate hazard as the mean over binary observations
# (one simulated hazard value per prior sample)
prior_pred_haz <- prior_pred %>%
group_by(iter) %>%
summarise(pred_haz = mean(y_pred))
# plot the distribution of simulated hazard values
ggplot(prior_pred_haz, aes(x=pred_haz)) +
geom_histogram(binwidth=.01) +
scale_x_continuous(limits = c(0,1)) +
labs(x = "simulated hazard",
title = "Prior predictive distribution",
subtitle = str_c(N_samples," samples; ",N_obs," observations per sample"))
## Warning: Removed 2 rows containing missing values or values outside the scale range
## (`geom_bar()`).
The shape of the prior predictive distribution looks as expected (compare with Figure X in the supplementary material): hazard values below .5 are more likely than hazard values above .5.
Now use add_predicted_draws() to simulate prior-based data for 10 bins and 6 subjects, while taking into account samples from the priors for the standard deviation of the random effects, and correlations between parameters.
# Generate prior predictive hazard functions for 6 participants:
# one row per (timebin, pid) combination.
newdata_prior <- tibble(timebin = 6:15) %>%
  expand_grid(pid = 1:6)
# Number of simulated binary observations per prior draw.
N_obs <- 100
set.seed(2)
# Call add_predicted_draws() once per simulated observation and bind all
# results at the end (avoids growing a data frame with bind_rows() inside
# the loop; FALSE spelled out instead of F).
df_pred <- bind_rows(lapply(seq_len(N_obs), function(i) {
  add_predicted_draws(model_M0i_prior,
                      newdata = newdata_prior,
                      summary = FALSE,
                      ndraws = NULL) %>%
    mutate(obs = i)
}))
# calculate hazard per draw (average across observations)
df_pred_haz <- df_pred %>%
group_by(pid,timebin,.draw) %>%
summarise(pred_haz = mean(.prediction))
# plot 40 draws per participant
ggplot(data = df_pred_haz %>% filter(.draw < 41),
aes(x=timebin, y=pred_haz, group=.draw)) +
geom_line(color="black") +
geom_line(data=df_pred_haz %>% filter(.draw == 20),
color="red",
linewidth=2) +
scale_y_continuous(limits = c(0,1)) +
scale_x_continuous(limits = c(6,15), breaks = c(6:15)) +
labs(x = "time bin",
y = "simulated hazard",
title = "Prior predictive distributions",
subtitle = str_c("8000 samples; ",N_obs," observations per sample")) +
facet_wrap(~pid)
Now we can fit model M0i to get the posterior distributions.
## Use forked parallel workers for the chains.
plan(multicore)
## Fit M0i on the data (no sample_prior = "only" here, so these are
## posterior draws). Same formula, priors, and sampler settings as the
## prior-only run; the fit is cached at `file`.
model_M0i <-
brm(data = data_M0i,
family = bernoulli(link="cloglog"),
formula = event ~ 0 + timebin + (0 + timebin | pid),
prior = priors_M0i,
chains = 4, cores = 4,
iter = 3000, warmup = 1000,
control = list(adapt_delta = 0.999,
step_size = 0.04,
max_treedepth = 12),
seed = 12, init = "0",
file = "Tutorial_2_Bayesian/models/model_M0i")
This took 28 minutes on a MacBook Pro (Sonoma 14.6.1 OS, 18GB Memory, M3 Pro Chip).
model_M0i <- readRDS("Tutorial_2_Bayesian/models/model_M0i.rds")
summary(model_M0i)
## Family: bernoulli
## Links: mu = cloglog
## Formula: event ~ 0 + timebin + (0 + timebin | pid)
## Data: data_M0i (Number of observations: 12840)
## Draws: 4 chains, each with iter = 3000; warmup = 1000; thin = 1;
## total post-warmup draws = 8000
##
## Multilevel Hyperparameters:
## ~pid (Number of levels: 6)
## Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS
## sd(timebin6) 0.87 0.32 0.42 1.66 1.00 4730
## sd(timebin7) 0.69 0.26 0.33 1.35 1.00 5132
## sd(timebin8) 0.54 0.22 0.25 1.10 1.00 5334
## sd(timebin9) 0.50 0.20 0.25 1.01 1.00 4992
## sd(timebin10) 0.62 0.24 0.31 1.23 1.00 4881
## sd(timebin11) 0.49 0.22 0.21 1.05 1.00 4474
## sd(timebin12) 0.29 0.17 0.06 0.71 1.00 3888
## sd(timebin13) 0.23 0.19 0.01 0.73 1.00 3395
## sd(timebin14) 0.34 0.28 0.01 1.05 1.00 3660
## sd(timebin15) 0.58 0.35 0.05 1.39 1.00 3564
## cor(timebin6,timebin7) 0.25 0.25 -0.26 0.70 1.00 8616
## cor(timebin6,timebin8) 0.05 0.25 -0.44 0.53 1.00 8715
## cor(timebin7,timebin8) 0.22 0.25 -0.29 0.67 1.00 9512
## cor(timebin6,timebin9) -0.04 0.24 -0.51 0.43 1.00 10869
## cor(timebin7,timebin9) 0.01 0.24 -0.46 0.46 1.00 9110
## cor(timebin8,timebin9) 0.15 0.25 -0.35 0.61 1.00 9319
## cor(timebin6,timebin10) -0.09 0.25 -0.55 0.40 1.00 10723
## cor(timebin7,timebin10) -0.11 0.24 -0.55 0.38 1.00 8918
## cor(timebin8,timebin10) 0.03 0.25 -0.44 0.51 1.00 8862
## cor(timebin9,timebin10) 0.25 0.25 -0.27 0.70 1.00 7887
## cor(timebin6,timebin11) 0.02 0.25 -0.48 0.50 1.00 9775
## cor(timebin7,timebin11) -0.04 0.25 -0.51 0.42 1.00 10157
## cor(timebin8,timebin11) -0.00 0.25 -0.48 0.48 1.00 8397
## cor(timebin9,timebin11) 0.13 0.25 -0.38 0.59 1.00 8883
## cor(timebin10,timebin11) 0.24 0.25 -0.27 0.69 1.00 7293
## cor(timebin6,timebin12) -0.02 0.26 -0.52 0.48 1.00 13590
## cor(timebin7,timebin12) -0.11 0.26 -0.60 0.40 1.00 10886
## cor(timebin8,timebin12) -0.09 0.26 -0.58 0.43 1.00 9700
## cor(timebin9,timebin12) 0.08 0.26 -0.42 0.55 1.00 8514
## cor(timebin10,timebin12) 0.18 0.26 -0.35 0.65 1.00 7826
## cor(timebin11,timebin12) 0.17 0.26 -0.35 0.65 1.00 7448
## cor(timebin6,timebin13) -0.10 0.27 -0.60 0.45 1.00 14128
## cor(timebin7,timebin13) -0.08 0.27 -0.59 0.45 1.00 12113
## cor(timebin8,timebin13) -0.01 0.27 -0.53 0.51 1.00 11687
## cor(timebin9,timebin13) 0.07 0.27 -0.46 0.57 1.00 9848
## cor(timebin10,timebin13) 0.12 0.27 -0.42 0.61 1.00 8309
## cor(timebin11,timebin13) 0.10 0.27 -0.43 0.60 1.00 7345
## cor(timebin12,timebin13) 0.06 0.27 -0.47 0.59 1.00 6174
## cor(timebin6,timebin14) -0.01 0.27 -0.52 0.51 1.00 14558
## cor(timebin7,timebin14) -0.06 0.27 -0.58 0.47 1.00 12396
## cor(timebin8,timebin14) -0.09 0.27 -0.59 0.46 1.00 10910
## cor(timebin9,timebin14) -0.06 0.27 -0.57 0.47 1.00 9645
## cor(timebin10,timebin14) 0.02 0.26 -0.49 0.52 1.00 9186
## cor(timebin11,timebin14) 0.07 0.27 -0.45 0.57 1.00 7755
## cor(timebin12,timebin14) 0.05 0.27 -0.48 0.56 1.00 6286
## cor(timebin13,timebin14) 0.03 0.28 -0.50 0.56 1.00 5495
## cor(timebin6,timebin15) -0.11 0.27 -0.60 0.43 1.00 13387
## cor(timebin7,timebin15) -0.13 0.26 -0.60 0.40 1.00 11988
## cor(timebin8,timebin15) -0.05 0.26 -0.55 0.47 1.00 10620
## cor(timebin9,timebin15) 0.05 0.26 -0.46 0.55 1.00 9850
## cor(timebin10,timebin15) 0.14 0.26 -0.38 0.61 1.00 7818
## cor(timebin11,timebin15) 0.12 0.26 -0.41 0.60 1.00 7565
## cor(timebin12,timebin15) 0.09 0.26 -0.42 0.58 1.00 5803
## cor(timebin13,timebin15) 0.08 0.28 -0.46 0.59 1.00 5481
## cor(timebin14,timebin15) 0.06 0.28 -0.49 0.57 1.00 5351
## Tail_ESS
## sd(timebin6) 6304
## sd(timebin7) 5876
## sd(timebin8) 6477
## sd(timebin9) 5986
## sd(timebin10) 5771
## sd(timebin11) 5748
## sd(timebin12) 4142
## sd(timebin13) 4078
## sd(timebin14) 4433
## sd(timebin15) 3487
## cor(timebin6,timebin7) 6385
## cor(timebin6,timebin8) 6472
## cor(timebin7,timebin8) 6697
## cor(timebin6,timebin9) 6032
## cor(timebin7,timebin9) 6570
## cor(timebin8,timebin9) 6832
## cor(timebin6,timebin10) 6492
## cor(timebin7,timebin10) 6365
## cor(timebin8,timebin10) 6634
## cor(timebin9,timebin10) 6360
## cor(timebin6,timebin11) 6306
## cor(timebin7,timebin11) 6681
## cor(timebin8,timebin11) 6697
## cor(timebin9,timebin11) 6965
## cor(timebin10,timebin11) 6541
## cor(timebin6,timebin12) 6191
## cor(timebin7,timebin12) 6203
## cor(timebin8,timebin12) 5960
## cor(timebin9,timebin12) 6479
## cor(timebin10,timebin12) 6293
## cor(timebin11,timebin12) 7097
## cor(timebin6,timebin13) 5524
## cor(timebin7,timebin13) 6296
## cor(timebin8,timebin13) 6521
## cor(timebin9,timebin13) 6682
## cor(timebin10,timebin13) 6728
## cor(timebin11,timebin13) 6450
## cor(timebin12,timebin13) 6331
## cor(timebin6,timebin14) 6267
## cor(timebin7,timebin14) 6159
## cor(timebin8,timebin14) 6044
## cor(timebin9,timebin14) 6626
## cor(timebin10,timebin14) 6806
## cor(timebin11,timebin14) 7168
## cor(timebin12,timebin14) 6536
## cor(timebin13,timebin14) 5750
## cor(timebin6,timebin15) 5743
## cor(timebin7,timebin15) 6969
## cor(timebin8,timebin15) 6650
## cor(timebin9,timebin15) 5962
## cor(timebin10,timebin15) 6618
## cor(timebin11,timebin15) 6776
## cor(timebin12,timebin15) 6234
## cor(timebin13,timebin15) 6398
## cor(timebin14,timebin15) 6466
##
## Regression Coefficients:
## Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
## timebin6 -2.86 0.38 -3.51 -1.99 1.00 3227 4065
## timebin7 -2.21 0.29 -2.74 -1.57 1.00 4116 4841
## timebin8 -1.82 0.23 -2.28 -1.34 1.00 4281 5439
## timebin9 -1.26 0.21 -1.71 -0.85 1.00 4732 5220
## timebin10 -0.90 0.25 -1.45 -0.45 1.00 5264 5654
## timebin11 -0.76 0.22 -1.25 -0.40 1.00 5470 4661
## timebin12 -0.64 0.16 -1.00 -0.37 1.00 5514 4090
## timebin13 -0.64 0.15 -0.99 -0.38 1.00 5495 3650
## timebin14 -0.93 0.22 -1.44 -0.56 1.00 7312 4539
## timebin15 -1.07 0.29 -1.71 -0.55 1.00 7745 5047
##
## Draws were sampled using sample(hmc). For each parameter, Bulk_ESS
## and Tail_ESS are effective sample size measures, and Rhat is the potential
## scale reduction factor on split chains (at convergence, Rhat = 1).
formula(model_M0i)
## event ~ 0 + timebin + (0 + timebin | pid)
plot(model_M0i)
Second, we can fit the same model using a dummy coding approach, as follows.
Prepare the data file, and specify priors.
# select bin indicator variables
data_M0d <- ptb_data %>% select(pid, event, d6:d8, d10:d15)
priors_M0d <- c(
set_prior("skew_normal(-1,1,-2)", class = "b", coef = "d6"),
set_prior("skew_normal(-1,1,-2)", class = "b", coef = "d7"),
set_prior("skew_normal(-1,1,-2)", class = "b", coef = "d8"),
set_prior("skew_normal(-1,1,-2)", class = "b", coef = "d10"),
set_prior("skew_normal(-1,1,-2)", class = "b", coef = "d11"),
set_prior("skew_normal(-1,1,-2)", class = "b", coef = "d12"),
set_prior("skew_normal(-1,1,-2)", class = "b", coef = "d13"),
set_prior("skew_normal(-1,1,-2)", class = "b", coef = "d14"),
set_prior("skew_normal(-1,1,-2)", class = "b", coef = "d15"),
set_prior("skew_normal(-1,1,-2)", class = "b", coef = "Intercept"),
set_prior("normal(0, 1)", class = "sd"),
set_prior("lkj(2)", class = "cor")
)
Fit the model with dummy coding (M0d).
## Use forked parallel workers for the chains.
plan(multicore)
## Dummy-coded variant of M0i: the Intercept stands in for the omitted d9
## bin, so the d* coefficients are deviations from bin 9 — presumably the
## same model as M0i in a different parameterization (verify via fixef()).
model_M0d <-
brm(data = data_M0d,
family = bernoulli(link="cloglog"),
formula = event ~ 0 + d6 + d7 + d8 + Intercept + d10 + d11 + d12 + d13 + d14 + d15 +
(d6 + d7 + d8 + 1 + d10 + d11 + d12 + d13 + d14 + d15 | pid),
prior = priors_M0d,
chains = 4, cores = 4,
iter = 3000, warmup = 1000,
control = list(adapt_delta = 0.999,
step_size = 0.04,
max_treedepth = 12),
seed = 12, init = "0",
file = "Tutorial_2_Bayesian/models/model_M0d")
M0d took about 77 minutes.
model_M0d <- readRDS("Tutorial_2_Bayesian/models/model_M0d.rds")
formula(model_M0d)
## event ~ 0 + d6 + d7 + d8 + Intercept + d10 + d11 + d12 + d13 + d14 + d15 + (d6 + d7 + d8 + 1 + d10 + d11 + d12 + d13 + d14 + d15 | pid)
fixef(model_M0d)
## Estimate Est.Error Q2.5 Q97.5
## d6 -1.9332829 0.4321629 -2.787111 -1.0573393
## d7 -1.3389718 0.3619834 -2.114674 -0.6658691
## d8 -0.8722236 0.2800040 -1.516462 -0.3925590
## Intercept -1.1067777 0.2591651 -1.664300 -0.6149630
## d10 -0.2136246 0.4512951 -1.255833 0.4375740
## d11 -0.2686736 0.4755249 -1.336390 0.4656902
## d12 -0.1743704 0.4985607 -1.289036 0.5779741
## d13 -0.1899364 0.5138580 -1.299201 0.6065687
## d14 -0.7313939 0.4979868 -1.798154 0.1537556
## d15 -0.6647565 0.4946424 -1.724210 0.1838370
Third, we can make assumptions to simplify the model. Instead of using a general specification of TIME (bin rank), we can treat TIME as a continuous variable and make assumptions about how cloglog-hazard changes over time within a trial. For example, we might assume that cloglog-hazard changes in a linear (“period_9”) + quadratic (“period_9_sq”) fashion over time bins within a trial.
Prepare the data file, and specify priors.
## Data for model M0c: continuous TIME (period_9, centered on bin 9) and
## its square, for a linear + quadratic specification of TIME.
data_M0c <- ptb_data %>%
select(pid, event, period_9) %>%
mutate(period_9_sq = period_9^2)
## Priors: skew-normal on the Intercept (cloglog hazard at bin 9), a normal
## prior centered on 0.1 for the linear trend, and a tight zero-centered
## normal for the quadratic term.
priors_M0c <- c(
set_prior("skew_normal(-1,1,-2)", class = "b", coef = "Intercept"),
set_prior("normal(0.1,.1)", class = "b", coef = "period_9"),
set_prior("normal(0,.06)", class= "b", coef="period_9_sq"),
set_prior("normal(0, 1)", class = "sd"),
set_prior("lkj(2)", class = "cor")
)
Perform a prior predictive check to see if the binary observations generated from the specified prior distributions reflect our prior beliefs. The following code allows you to change the coefficients and generate (cloglog-)hazard functions. Once you have found the coefficients for period_9 and period_9_sq that generate a variety of hazard functions spanning your prior beliefs, you can then use them to estimate a mean and standard deviation for each coefficient.
## Four candidate quadratic cloglog-hazard curves over centered time x,
## plus their inverse-cloglog transforms to the hazard scale. Adjust the
## coefficients to explore which shapes the priors should allow.
dat <- tibble(x = -3:6,
x2 = x^2,
cloglog1 = -1.8 + 0.3*x - 0.1*x2,
cloglog2 = -1 + 0.4*x - 0.1*x2,
cloglog3 = -3 - 0.1*x + 0.12*x2,
cloglog4 = -0.40 + 0.2*x - 0.04*x2,
haz1 = 1 - exp(-1*exp(cloglog1)),
haz2 = 1 - exp(-1*exp(cloglog2)),
haz3 = 1 - exp(-1*exp(cloglog3)),
haz4 = 1 - exp(-1*exp(cloglog4)))
## Left panel: the four curves on the cloglog scale.
p1<-ggplot(dat, aes(x=x)) +
geom_line(aes(y=cloglog1), color="black") +
geom_line(aes(y=cloglog2), color="green") +
geom_line(aes(y=cloglog3), color="red") +
geom_line(aes(y=cloglog4), color="blue") +
scale_y_continuous(limits=c(-6,2)) +
scale_x_continuous(breaks=c(-3:6)) +
labs(x = "timebin", y = "cloglog-hazard")
## Right panel: the same curves on the hazard (probability) scale.
p2<-ggplot(dat, aes(x=x))+
geom_line(aes(y=haz1), color="black") +
geom_line(aes(y=haz2), color="green") +
geom_line(aes(y=haz3), color="red") +
geom_line(aes(y=haz4), color="blue") +
scale_y_continuous(limits=c(0,1)) +
scale_x_continuous(breaks=c(-3:6)) +
labs(x = "timebin", y = "hazard")
## patchwork: place the two panels side by side.
p1|p2
First, sample the prior distributions, and set the variance of the normal prior for class “sd” to a low value, to minimize the effect of the random variables while inspecting the relation between the selected priors for period_9 and period_9_sq and the shape of the prior predictive distributions.
## Priors for the prior predictive check only: identical to priors_M0c
## except for a much tighter sd prior, so random effects barely perturb
## the curves implied by the fixed-effect priors.
priors_M0c_ppc <- c(
set_prior("skew_normal(-1,1,-2)", class = "b", coef = "Intercept"),
set_prior("normal(0.1,.1)", class = "b", coef = "period_9"),
set_prior("normal(0,.06)", class= "b", coef="period_9_sq"),
set_prior("normal(0, .1)", class = "sd"), # set standard deviation to a small value for now
set_prior("lkj(2)", class = "cor")
)
## Use forked parallel workers for the chains.
plan(multicore)
## Prior-only fit of M0c (sample_prior = "only"); cached at `file`.
model_M0c_prior <-
brm(data = data_M0c,
family = bernoulli(link="cloglog"),
formula = event ~ 0 + Intercept + period_9 + period_9_sq +
(0 + Intercept + period_9 + period_9_sq | pid),
prior = priors_M0c_ppc,
chains = 4, cores = 4,
iter = 3000, warmup = 1000,
control = list(adapt_delta = 0.999,
step_size = 0.04,
max_treedepth = 12),
seed = 12, init = "0",
sample_prior = "only",
file = "Tutorial_2_Bayesian/models/model_M0c_prior")
## Reload the cached fit from disk.
model_M0c_prior <- readRDS("Tutorial_2_Bayesian/models/model_M0c_prior.rds")
Check the prior distributions.
# extract prior draws (renamed from "post": this is a prior-only fit, so
# the draws come from the priors, not a posterior)
prior_draws <- as_draws_df(model_M0c_prior) %>%
  select(-lp__) %>%
  as_tibble()
# long format: one row per (draw, fixed-effect parameter)
tidy_prior <- prior_draws %>%
  select(starts_with("b_"), .chain, .iteration, .draw) %>%
  rename(chain=.chain, iter=.iteration, draw=.draw) %>%
  pivot_longer(-c(chain, draw, iter))
# plot: half-eye densities with 80% and 99% median quantile intervals
# (show.legend = FALSE spelled out instead of F)
ggplot(tidy_prior, aes(x = name, y = value, fill = name)) +
  stat_halfeye(alpha=0.7,
               point_interval = "median_qi",
               .width=c(.8,.99),
               show.legend = FALSE) +
  labs(x = "parameter",
       title = "Samples from the prior distributions")
Generate prior predictive check for each participant.
# Generate prior predictive hazard functions for 6 participants.
# Number of simulated binary observations per prior draw.
N_obs <- 100 # increase for smoother results
set.seed(1)
# make new data: one row per (centered time bin, pid) combination
newdata_prior <- tibble(period_9 = c(-3:6)) %>%
  mutate(period_9_sq = period_9^2) %>%
  expand_grid(pid = 1:6)
# Call add_predicted_draws() once per simulated observation and bind all
# results at the end (avoids growing a data frame with bind_rows() inside
# the loop; FALSE spelled out instead of F).
df_pred <- bind_rows(lapply(seq_len(N_obs), function(i) {
  add_predicted_draws(model_M0c_prior,
                      newdata = newdata_prior,
                      summary = FALSE,
                      ndraws = NULL) %>%
    mutate(obs = i)
}))
# calculate hazard per draw (average across obs)
df_pred_haz <- df_pred %>%
  group_by(pid,period_9,.draw) %>%
  summarise(pred_haz = mean(.prediction))
# plot draws 1-199 per participant, highlighting three at random
# (idiom fixes: <- for assignment, FALSE instead of F)
# NOTE(review): sel is sampled from 1:200 but the black layer keeps
# .draw < 200 only — a highlighted draw of 200 would have no black base
# line underneath; confirm whether <= 200 was intended
sel <- sample(1:200, 3, replace = FALSE)
ggplot(df_pred_haz %>%
         filter(.draw <200), aes(x=period_9, y=pred_haz, group=.draw)) +
  geom_line(color="black") +
  # highlight three randomly chosen draws in color
  geom_line(data=df_pred_haz %>% filter(.draw ==sel[1]),
            color="red", linewidth=2) +
  geom_line(data=df_pred_haz %>% filter(.draw == sel[2]),
            color="green", linewidth=2) +
  geom_line(data=df_pred_haz %>% filter(.draw == sel[3]),
            color="blue", linewidth=2) +
  scale_y_continuous(limits = c(0,1)) +
  scale_x_continuous(limits = c(-3,6), breaks = c(-3:6)) +
  facet_wrap(~pid)
A wide range of differently shaped hazard functions can be observed, and early bins are unlikely to have high hazard values, consistent with our prior beliefs. Note that some simulated hazard functions start off with a high hazard in bin -3. These reflect a combination of negative slopes and a positive coefficient for period_9_sq. These possibilities are required to encompass hazard functions that initially decrease and then increase, as observed in the empirical hazard functions for incongruent primes.
Now increase the variance of the normal prior for class "sd" back to its original value, and fit model M0c to get the posterior distributions.
plan(multicore)
model_M0c <-
brm(data = data_M0c,
family = bernoulli(link="cloglog"),
formula = event ~ 0 + Intercept + period_9 + period_9_sq +
(0 + Intercept + period_9 + period_9_sq | pid),
prior = priors_M0c,
chains = 4, cores = 4,
iter = 3000, warmup = 1000,
control = list(adapt_delta = 0.999,
step_size = 0.04,
max_treedepth = 12),
seed = 12, init = "0",
file = "Tutorial_2_Bayesian/models/model_M0c")
Model_M0c took about 71 minutes to run.
model_M0c <- readRDS("Tutorial_2_Bayesian/models/model_M0c.rds")
summary(model_M0c)
## Family: bernoulli
## Links: mu = cloglog
## Formula: event ~ 0 + Intercept + period_9 + period_9_sq + (0 + Intercept + period_9 + period_9_sq | pid)
## Data: data_M0c (Number of observations: 12840)
## Draws: 4 chains, each with iter = 3000; warmup = 1000; thin = 1;
## total post-warmup draws = 8000
##
## Multilevel Hyperparameters:
## ~pid (Number of levels: 6)
## Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS
## sd(Intercept) 0.48 0.20 0.23 0.97 1.00 3727
## sd(period_9) 0.35 0.16 0.15 0.76 1.00 2875
## sd(period_9_sq) 0.06 0.03 0.03 0.13 1.00 3189
## cor(Intercept,period_9) 0.07 0.35 -0.61 0.70 1.00 2887
## cor(Intercept,period_9_sq) -0.29 0.31 -0.80 0.39 1.00 5210
## cor(period_9,period_9_sq) -0.51 0.32 -0.94 0.24 1.00 4871
## Tail_ESS
## sd(Intercept) 4784
## sd(period_9) 4107
## sd(period_9_sq) 3940
## cor(Intercept,period_9) 4266
## cor(Intercept,period_9_sq) 5562
## cor(period_9,period_9_sq) 4983
##
## Regression Coefficients:
## Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
## Intercept -1.24 0.22 -1.71 -0.84 1.00 2290 3759
## period_9 0.24 0.10 0.03 0.42 1.00 3740 4053
## period_9_sq -0.05 0.03 -0.09 0.01 1.00 4074 4193
##
## Draws were sampled using sample(hmc). For each parameter, Bulk_ESS
## and Tail_ESS are effective sample size measures, and Rhat is the potential
## scale reduction factor on split chains (at convergence, Rhat = 1).
formula(model_M0c)
## event ~ 0 + Intercept + period_9 + period_9_sq + (0 + Intercept + period_9 + period_9_sq | pid)
The next models we fit also include our predictor variable prime type. First, we can use index coding for this categorical predictor.
Prepare the data file, and specify priors.
## Data for model M1i: add the index-coded prime-type factor to the M0i data.
data_M1i <- ptb_data %>% select(pid, event, timebin, prime)
## Priors: same structure as priors_M0i; the skew-normal prior now applies
## to all 30 timebin-by-prime grand intercepts (class "b").
priors_M1i <- c(
set_prior("skew_normal(-1,1,-2)", class = "b"),
set_prior("normal(0, 1)", class = "sd"),
set_prior("lkj(2)", class = "cor")
)
Fit model M1i. The interaction between timebin and prime will specify 30 grand intercepts.
## Use forked parallel workers for the chains.
plan(multicore)
## Fit M1i: the timebin:prime interaction with no main effects gives one
## grand intercept per timebin-by-prime cell (10 bins x 3 primes = 30),
## with matching by-participant random effects. Fit is cached at `file`.
model_M1i <-
brm(data = data_M1i,
family = bernoulli(link="cloglog"),
formula = event ~ 0 + timebin:prime +
(0 + timebin:prime | pid),
prior = priors_M1i,
chains = 4, cores = 4,
iter = 3000, warmup = 1000,
control = list(adapt_delta = 0.999,
step_size = 0.04,
max_treedepth = 12),
seed = 12, init = "0",
file = "Tutorial_2_Bayesian/models/model_M1i")
Model M1i took about 124 minutes to fit.
model_M1i <- readRDS("Tutorial_2_Bayesian/models/model_M1i.rds")
fixef(model_M1i)
## Estimate Est.Error Q2.5 Q97.5
## timebin6:prime1 -4.4496709 0.7322019 -5.5186672 -2.6172545
## timebin7:prime1 -2.9756681 0.4525940 -3.7424620 -1.9015655
## timebin8:prime1 -1.9072521 0.2783905 -2.4772665 -1.3455780
## timebin9:prime1 -1.1509991 0.2466627 -1.6972177 -0.6858400
## timebin10:prime1 -0.6150386 0.2901719 -1.2785840 -0.1378814
## timebin11:prime1 -0.4684899 0.2254269 -1.0113615 -0.1142685
## timebin12:prime1 -0.4547018 0.2177043 -0.9787083 -0.1061844
## timebin13:prime1 -0.6302033 0.3147734 -1.3489175 -0.1414447
## timebin14:prime1 -0.8417804 0.3360386 -1.6159770 -0.2757566
## timebin15:prime1 -1.2681006 0.4575596 -2.3015065 -0.4789412
## timebin6:prime2 -2.2635497 0.3177752 -2.8664268 -1.5812525
## timebin7:prime2 -1.5347939 0.2431984 -2.0381137 -1.0528518
## timebin8:prime2 -1.2777413 0.2531524 -1.8227535 -0.7926707
## timebin9:prime2 -0.8541505 0.3127625 -1.5475030 -0.3109109
## timebin10:prime2 -0.8885676 0.2999444 -1.5431482 -0.3650001
## timebin11:prime2 -0.8094231 0.2945537 -1.4762308 -0.3187765
## timebin12:prime2 -0.5831442 0.2450133 -1.1480717 -0.1736950
## timebin13:prime2 -0.9003810 0.3698885 -1.7097407 -0.2575006
## timebin14:prime2 -1.2750358 0.4208509 -2.1808997 -0.5112653
## timebin15:prime2 -1.3134203 0.4966061 -2.3831527 -0.4397186
## timebin6:prime3 -2.2520279 0.4140011 -3.0427777 -1.3898838
## timebin7:prime3 -2.0632488 0.4948714 -3.0413115 -1.0857520
## timebin8:prime3 -2.6500772 0.5383075 -3.6544872 -1.5199973
## timebin9:prime3 -2.4139511 0.3149837 -3.0239505 -1.7384690
## timebin10:prime3 -1.8710860 0.4376559 -2.8024945 -1.0426678
## timebin11:prime3 -1.3303723 0.3261964 -2.0448833 -0.7438644
## timebin12:prime3 -1.0256625 0.2455494 -1.5681517 -0.5927057
## timebin13:prime3 -0.9584634 0.2732235 -1.5733905 -0.4806148
## timebin14:prime3 -1.0662756 0.2444208 -1.6111005 -0.6284875
## timebin15:prime3 -1.1332311 0.3770779 -1.9308520 -0.4511871
formula(model_M1i)
## event ~ 0 + timebin:prime + (0 + timebin:prime | pid)
Second, if we want to make assumptions about (1) how hazard changes over TIME in the reference condition (blank prime), and (2) how the effect of congruent and incongruent primes change over TIME (relax the proportionality assumption), then we can switch to a dummy coding approach and treat TIME as a continuous variable. For example, we may assume that hazard can change in a linear + quadratic fashion over time for a blank prime, and that the effects of congruent and incongruent primes relative to blank change in a linear + quadratic fashion, and fit the following model (M1d).
## Data for model M1d: continuous TIME (linear + quadratic) plus the
## dummy-coded condition factor (reference level = blank).
data_M1d <- ptb_data %>% select(pid, event, period_9, condition) %>%
  mutate(period_9_sq = period_9^2)
# check which priors to set for model_M1d
check_priors_M1d <- get_prior(formula = event ~ 0 + Intercept +
                                condition*period_9 +
                                condition*period_9_sq +
                                (1 + condition*period_9 +
                                   condition*period_9_sq | pid),
                              data = data_M1d,
                              # bug fix: was misspelled "familiy", which
                              # get_prior silently swallowed via ..., so the
                              # priors were listed for the default family
                              # instead of the cloglog Bernoulli model
                              family = bernoulli(link="cloglog"))
## Priors for the blank-prime trajectory (Intercept, period_9, period_9_sq)
## and for the condition offsets; sd/cor priors as in the earlier models.
priors_M1d <- c(
  set_prior("skew_normal(-1,1,-2)", class = "b", coef = "Intercept"),
  set_prior("normal(0.1,.1)", class = "b", coef = "period_9"),
  set_prior("normal(0,.06)", class= "b", coef = "period_9_sq"),
  set_prior("normal(0,.4)", class = "b", coef = "conditioncongruent"),
  set_prior("normal(0,.4)", class = "b", coef = "conditionincongruent"),
  set_prior("normal(0, 1)", class = "sd"),
  set_prior("lkj(2)", class = "cor")
)
# Enable parallel processing via the future backend.
plan(multicore)
# Fit model M1d: dummy-coded condition with linear + quadratic continuous
# TIME effects, and correlated participant-level terms for every
# population-level effect.
model_M1d <-
brm(data = data_M1d,
family = bernoulli(link="cloglog"),
event ~ 0 + Intercept +
condition*period_9 +
condition*period_9_sq +
(1 + condition*period_9 +
condition*period_9_sq | pid),
prior = priors_M1d,
# 4 chains x (3000 - 1000) iterations = 8000 post-warmup draws
chains = 4, cores = 4,
iter = 3000, warmup = 1000,
# conservative sampler settings (high adapt_delta, deeper trees) to avoid
# divergent transitions; these make sampling slow (~145 minutes, see text)
control = list(adapt_delta = 0.999,
step_size = 0.04,
max_treedepth = 12),
seed = 12, init = "0",
# cache the fitted model; reruns load the .rds instead of resampling
file = "Tutorial_2_Bayesian/models/model_M1d")
Model_M1d took about 145 minutes to run. Note that duplicate terms in the model formula (e.g., condition) are ignored.
# Reload the cached fit from disk and print its full summary.
model_M1d <- file.path("Tutorial_2_Bayesian", "models", "model_M1d.rds") %>%
  readRDS()
summary(model_M1d)
## Family: bernoulli
## Links: mu = cloglog
## Formula: event ~ 0 + Intercept + condition * period_9 + condition * period_9_sq + (1 + condition * period_9 + condition * period_9_sq | pid)
## Data: data_M1d (Number of observations: 12840)
## Draws: 4 chains, each with iter = 3000; warmup = 1000; thin = 1;
## total post-warmup draws = 8000
##
## Multilevel Hyperparameters:
## ~pid (Number of levels: 6)
## Estimate
## sd(Intercept) 0.69
## sd(conditioncongruent) 0.64
## sd(conditionincongruent) 0.54
## sd(period_9) 0.70
## sd(period_9_sq) 0.11
## sd(conditioncongruent:period_9) 0.14
## sd(conditionincongruent:period_9) 0.44
## sd(conditioncongruent:period_9_sq) 0.06
## sd(conditionincongruent:period_9_sq) 0.03
## cor(Intercept,conditioncongruent) -0.20
## cor(Intercept,conditionincongruent) -0.18
## cor(conditioncongruent,conditionincongruent) 0.19
## cor(Intercept,period_9) 0.03
## cor(conditioncongruent,period_9) 0.05
## cor(conditionincongruent,period_9) -0.19
## cor(Intercept,period_9_sq) -0.19
## cor(conditioncongruent,period_9_sq) 0.04
## cor(conditionincongruent,period_9_sq) 0.25
## cor(period_9,period_9_sq) -0.25
## cor(Intercept,conditioncongruent:period_9) 0.14
## cor(conditioncongruent,conditioncongruent:period_9) -0.07
## cor(conditionincongruent,conditioncongruent:period_9) -0.13
## cor(period_9,conditioncongruent:period_9) -0.02
## cor(period_9_sq,conditioncongruent:period_9) -0.09
## cor(Intercept,conditionincongruent:period_9) 0.21
## cor(conditioncongruent,conditionincongruent:period_9) -0.14
## cor(conditionincongruent,conditionincongruent:period_9) -0.25
## cor(period_9,conditionincongruent:period_9) 0.03
## cor(period_9_sq,conditionincongruent:period_9) -0.27
## cor(conditioncongruent:period_9,conditionincongruent:period_9) 0.16
## cor(Intercept,conditioncongruent:period_9_sq) 0.10
## cor(conditioncongruent,conditioncongruent:period_9_sq) -0.24
## cor(conditionincongruent,conditioncongruent:period_9_sq) -0.15
## cor(period_9,conditioncongruent:period_9_sq) 0.06
## cor(period_9_sq,conditioncongruent:period_9_sq) -0.12
## cor(conditioncongruent:period_9,conditioncongruent:period_9_sq) -0.00
## cor(conditionincongruent:period_9,conditioncongruent:period_9_sq) 0.12
## cor(Intercept,conditionincongruent:period_9_sq) 0.09
## cor(conditioncongruent,conditionincongruent:period_9_sq) -0.01
## cor(conditionincongruent,conditionincongruent:period_9_sq) 0.00
## cor(period_9,conditionincongruent:period_9_sq) -0.03
## cor(period_9_sq,conditionincongruent:period_9_sq) -0.03
## cor(conditioncongruent:period_9,conditionincongruent:period_9_sq) -0.03
## cor(conditionincongruent:period_9,conditionincongruent:period_9_sq) -0.03
## cor(conditioncongruent:period_9_sq,conditionincongruent:period_9_sq) 0.03
## Est.Error
## sd(Intercept) 0.23
## sd(conditioncongruent) 0.23
## sd(conditionincongruent) 0.24
## sd(period_9) 0.25
## sd(period_9_sq) 0.06
## sd(conditioncongruent:period_9) 0.11
## sd(conditionincongruent:period_9) 0.19
## sd(conditioncongruent:period_9_sq) 0.04
## sd(conditionincongruent:period_9_sq) 0.02
## cor(Intercept,conditioncongruent) 0.25
## cor(Intercept,conditionincongruent) 0.25
## cor(conditioncongruent,conditionincongruent) 0.26
## cor(Intercept,period_9) 0.27
## cor(conditioncongruent,period_9) 0.27
## cor(conditionincongruent,period_9) 0.27
## cor(Intercept,period_9_sq) 0.26
## cor(conditioncongruent,period_9_sq) 0.26
## cor(conditionincongruent,period_9_sq) 0.26
## cor(period_9,period_9_sq) 0.27
## cor(Intercept,conditioncongruent:period_9) 0.27
## cor(conditioncongruent,conditioncongruent:period_9) 0.27
## cor(conditionincongruent,conditioncongruent:period_9) 0.27
## cor(period_9,conditioncongruent:period_9) 0.30
## cor(period_9_sq,conditioncongruent:period_9) 0.28
## cor(Intercept,conditionincongruent:period_9) 0.26
## cor(conditioncongruent,conditionincongruent:period_9) 0.26
## cor(conditionincongruent,conditionincongruent:period_9) 0.27
## cor(period_9,conditionincongruent:period_9) 0.29
## cor(period_9_sq,conditionincongruent:period_9) 0.28
## cor(conditioncongruent:period_9,conditionincongruent:period_9) 0.27
## cor(Intercept,conditioncongruent:period_9_sq) 0.26
## cor(conditioncongruent,conditioncongruent:period_9_sq) 0.27
## cor(conditionincongruent,conditioncongruent:period_9_sq) 0.27
## cor(period_9,conditioncongruent:period_9_sq) 0.30
## cor(period_9_sq,conditioncongruent:period_9_sq) 0.28
## cor(conditioncongruent:period_9,conditioncongruent:period_9_sq) 0.28
## cor(conditionincongruent:period_9,conditioncongruent:period_9_sq) 0.27
## cor(Intercept,conditionincongruent:period_9_sq) 0.28
## cor(conditioncongruent,conditionincongruent:period_9_sq) 0.28
## cor(conditionincongruent,conditionincongruent:period_9_sq) 0.28
## cor(period_9,conditionincongruent:period_9_sq) 0.30
## cor(period_9_sq,conditionincongruent:period_9_sq) 0.29
## cor(conditioncongruent:period_9,conditionincongruent:period_9_sq) 0.28
## cor(conditionincongruent:period_9,conditionincongruent:period_9_sq) 0.28
## cor(conditioncongruent:period_9_sq,conditionincongruent:period_9_sq) 0.28
## l-95% CI
## sd(Intercept) 0.38
## sd(conditioncongruent) 0.32
## sd(conditionincongruent) 0.23
## sd(period_9) 0.34
## sd(period_9_sq) 0.04
## sd(conditioncongruent:period_9) 0.01
## sd(conditionincongruent:period_9) 0.22
## sd(conditioncongruent:period_9_sq) 0.02
## sd(conditionincongruent:period_9_sq) 0.00
## cor(Intercept,conditioncongruent) -0.64
## cor(Intercept,conditionincongruent) -0.64
## cor(conditioncongruent,conditionincongruent) -0.34
## cor(Intercept,period_9) -0.51
## cor(conditioncongruent,period_9) -0.48
## cor(conditionincongruent,period_9) -0.68
## cor(Intercept,period_9_sq) -0.66
## cor(conditioncongruent,period_9_sq) -0.47
## cor(conditionincongruent,period_9_sq) -0.30
## cor(period_9,period_9_sq) -0.73
## cor(Intercept,conditioncongruent:period_9) -0.40
## cor(conditioncongruent,conditioncongruent:period_9) -0.57
## cor(conditionincongruent,conditioncongruent:period_9) -0.62
## cor(period_9,conditioncongruent:period_9) -0.60
## cor(period_9_sq,conditioncongruent:period_9) -0.62
## cor(Intercept,conditionincongruent:period_9) -0.32
## cor(conditioncongruent,conditionincongruent:period_9) -0.61
## cor(conditionincongruent,conditionincongruent:period_9) -0.71
## cor(period_9,conditionincongruent:period_9) -0.52
## cor(period_9_sq,conditionincongruent:period_9) -0.75
## cor(conditioncongruent:period_9,conditionincongruent:period_9) -0.39
## cor(Intercept,conditioncongruent:period_9_sq) -0.42
## cor(conditioncongruent,conditioncongruent:period_9_sq) -0.72
## cor(conditionincongruent,conditioncongruent:period_9_sq) -0.65
## cor(period_9,conditioncongruent:period_9_sq) -0.52
## cor(period_9_sq,conditioncongruent:period_9_sq) -0.63
## cor(conditioncongruent:period_9,conditioncongruent:period_9_sq) -0.54
## cor(conditionincongruent:period_9,conditioncongruent:period_9_sq) -0.42
## cor(Intercept,conditionincongruent:period_9_sq) -0.47
## cor(conditioncongruent,conditionincongruent:period_9_sq) -0.55
## cor(conditionincongruent,conditionincongruent:period_9_sq) -0.53
## cor(period_9,conditionincongruent:period_9_sq) -0.59
## cor(period_9_sq,conditionincongruent:period_9_sq) -0.57
## cor(conditioncongruent:period_9,conditionincongruent:period_9_sq) -0.57
## cor(conditionincongruent:period_9,conditionincongruent:period_9_sq) -0.58
## cor(conditioncongruent:period_9_sq,conditionincongruent:period_9_sq) -0.53
## u-95% CI
## sd(Intercept) 1.29
## sd(conditioncongruent) 1.19
## sd(conditionincongruent) 1.15
## sd(period_9) 1.30
## sd(period_9_sq) 0.27
## sd(conditioncongruent:period_9) 0.39
## sd(conditionincongruent:period_9) 0.92
## sd(conditioncongruent:period_9_sq) 0.15
## sd(conditionincongruent:period_9_sq) 0.09
## cor(Intercept,conditioncongruent) 0.32
## cor(Intercept,conditionincongruent) 0.32
## cor(conditioncongruent,conditionincongruent) 0.66
## cor(Intercept,period_9) 0.55
## cor(conditioncongruent,period_9) 0.55
## cor(conditionincongruent,period_9) 0.36
## cor(Intercept,period_9_sq) 0.35
## cor(conditioncongruent,period_9_sq) 0.53
## cor(conditionincongruent,period_9_sq) 0.71
## cor(period_9,period_9_sq) 0.30
## cor(Intercept,conditioncongruent:period_9) 0.64
## cor(conditioncongruent,conditioncongruent:period_9) 0.47
## cor(conditionincongruent,conditioncongruent:period_9) 0.41
## cor(period_9,conditioncongruent:period_9) 0.55
## cor(period_9_sq,conditioncongruent:period_9) 0.47
## cor(Intercept,conditionincongruent:period_9) 0.67
## cor(conditioncongruent,conditionincongruent:period_9) 0.36
## cor(conditionincongruent,conditionincongruent:period_9) 0.32
## cor(period_9,conditionincongruent:period_9) 0.58
## cor(period_9_sq,conditionincongruent:period_9) 0.31
## cor(conditioncongruent:period_9,conditionincongruent:period_9) 0.67
## cor(Intercept,conditioncongruent:period_9_sq) 0.59
## cor(conditioncongruent,conditioncongruent:period_9_sq) 0.32
## cor(conditionincongruent,conditioncongruent:period_9_sq) 0.39
## cor(period_9,conditioncongruent:period_9_sq) 0.62
## cor(period_9_sq,conditioncongruent:period_9_sq) 0.44
## cor(conditioncongruent:period_9,conditioncongruent:period_9_sq) 0.54
## cor(conditionincongruent:period_9,conditioncongruent:period_9_sq) 0.63
## cor(Intercept,conditionincongruent:period_9_sq) 0.60
## cor(conditioncongruent,conditionincongruent:period_9_sq) 0.53
## cor(conditionincongruent,conditionincongruent:period_9_sq) 0.55
## cor(period_9,conditionincongruent:period_9_sq) 0.56
## cor(period_9_sq,conditionincongruent:period_9_sq) 0.53
## cor(conditioncongruent:period_9,conditionincongruent:period_9_sq) 0.52
## cor(conditionincongruent:period_9,conditionincongruent:period_9_sq) 0.51
## cor(conditioncongruent:period_9_sq,conditionincongruent:period_9_sq) 0.56
## Rhat
## sd(Intercept) 1.00
## sd(conditioncongruent) 1.00
## sd(conditionincongruent) 1.00
## sd(period_9) 1.00
## sd(period_9_sq) 1.00
## sd(conditioncongruent:period_9) 1.00
## sd(conditionincongruent:period_9) 1.00
## sd(conditioncongruent:period_9_sq) 1.00
## sd(conditionincongruent:period_9_sq) 1.00
## cor(Intercept,conditioncongruent) 1.00
## cor(Intercept,conditionincongruent) 1.00
## cor(conditioncongruent,conditionincongruent) 1.00
## cor(Intercept,period_9) 1.00
## cor(conditioncongruent,period_9) 1.00
## cor(conditionincongruent,period_9) 1.00
## cor(Intercept,period_9_sq) 1.00
## cor(conditioncongruent,period_9_sq) 1.00
## cor(conditionincongruent,period_9_sq) 1.00
## cor(period_9,period_9_sq) 1.00
## cor(Intercept,conditioncongruent:period_9) 1.00
## cor(conditioncongruent,conditioncongruent:period_9) 1.00
## cor(conditionincongruent,conditioncongruent:period_9) 1.00
## cor(period_9,conditioncongruent:period_9) 1.00
## cor(period_9_sq,conditioncongruent:period_9) 1.00
## cor(Intercept,conditionincongruent:period_9) 1.00
## cor(conditioncongruent,conditionincongruent:period_9) 1.00
## cor(conditionincongruent,conditionincongruent:period_9) 1.00
## cor(period_9,conditionincongruent:period_9) 1.00
## cor(period_9_sq,conditionincongruent:period_9) 1.00
## cor(conditioncongruent:period_9,conditionincongruent:period_9) 1.00
## cor(Intercept,conditioncongruent:period_9_sq) 1.00
## cor(conditioncongruent,conditioncongruent:period_9_sq) 1.00
## cor(conditionincongruent,conditioncongruent:period_9_sq) 1.00
## cor(period_9,conditioncongruent:period_9_sq) 1.00
## cor(period_9_sq,conditioncongruent:period_9_sq) 1.00
## cor(conditioncongruent:period_9,conditioncongruent:period_9_sq) 1.00
## cor(conditionincongruent:period_9,conditioncongruent:period_9_sq) 1.00
## cor(Intercept,conditionincongruent:period_9_sq) 1.00
## cor(conditioncongruent,conditionincongruent:period_9_sq) 1.00
## cor(conditionincongruent,conditionincongruent:period_9_sq) 1.00
## cor(period_9,conditionincongruent:period_9_sq) 1.00
## cor(period_9_sq,conditionincongruent:period_9_sq) 1.00
## cor(conditioncongruent:period_9,conditionincongruent:period_9_sq) 1.00
## cor(conditionincongruent:period_9,conditionincongruent:period_9_sq) 1.00
## cor(conditioncongruent:period_9_sq,conditionincongruent:period_9_sq) 1.00
## Bulk_ESS
## sd(Intercept) 4683
## sd(conditioncongruent) 5380
## sd(conditionincongruent) 5169
## sd(period_9) 4935
## sd(period_9_sq) 3561
## sd(conditioncongruent:period_9) 3420
## sd(conditionincongruent:period_9) 4539
## sd(conditioncongruent:period_9_sq) 4697
## sd(conditionincongruent:period_9_sq) 4483
## cor(Intercept,conditioncongruent) 8378
## cor(Intercept,conditionincongruent) 9291
## cor(conditioncongruent,conditionincongruent) 7770
## cor(Intercept,period_9) 5186
## cor(conditioncongruent,period_9) 5365
## cor(conditionincongruent,period_9) 5750
## cor(Intercept,period_9_sq) 7878
## cor(conditioncongruent,period_9_sq) 7548
## cor(conditionincongruent,period_9_sq) 7471
## cor(period_9,period_9_sq) 9193
## cor(Intercept,conditioncongruent:period_9) 12412
## cor(conditioncongruent,conditioncongruent:period_9) 10882
## cor(conditionincongruent,conditioncongruent:period_9) 8748
## cor(period_9,conditioncongruent:period_9) 9809
## cor(period_9_sq,conditioncongruent:period_9) 7462
## cor(Intercept,conditionincongruent:period_9) 10215
## cor(conditioncongruent,conditionincongruent:period_9) 9571
## cor(conditionincongruent,conditionincongruent:period_9) 8519
## cor(period_9,conditionincongruent:period_9) 9353
## cor(period_9_sq,conditionincongruent:period_9) 7466
## cor(conditioncongruent:period_9,conditionincongruent:period_9) 6306
## cor(Intercept,conditioncongruent:period_9_sq) 11657
## cor(conditioncongruent,conditioncongruent:period_9_sq) 10565
## cor(conditionincongruent,conditioncongruent:period_9_sq) 8972
## cor(period_9,conditioncongruent:period_9_sq) 8443
## cor(period_9_sq,conditioncongruent:period_9_sq) 7396
## cor(conditioncongruent:period_9,conditioncongruent:period_9_sq) 7122
## cor(conditionincongruent:period_9,conditioncongruent:period_9_sq) 6655
## cor(Intercept,conditionincongruent:period_9_sq) 13067
## cor(conditioncongruent,conditionincongruent:period_9_sq) 11932
## cor(conditionincongruent,conditionincongruent:period_9_sq) 9642
## cor(period_9,conditionincongruent:period_9_sq) 9465
## cor(period_9_sq,conditionincongruent:period_9_sq) 7959
## cor(conditioncongruent:period_9,conditionincongruent:period_9_sq) 6518
## cor(conditionincongruent:period_9,conditionincongruent:period_9_sq) 6166
## cor(conditioncongruent:period_9_sq,conditionincongruent:period_9_sq) 5359
## Tail_ESS
## sd(Intercept) 5707
## sd(conditioncongruent) 5907
## sd(conditionincongruent) 6537
## sd(period_9) 5077
## sd(period_9_sq) 5133
## sd(conditioncongruent:period_9) 3926
## sd(conditionincongruent:period_9) 6039
## sd(conditioncongruent:period_9_sq) 4264
## sd(conditionincongruent:period_9_sq) 5199
## cor(Intercept,conditioncongruent) 5827
## cor(Intercept,conditionincongruent) 6397
## cor(conditioncongruent,conditionincongruent) 6254
## cor(Intercept,period_9) 5697
## cor(conditioncongruent,period_9) 5592
## cor(conditionincongruent,period_9) 6004
## cor(Intercept,period_9_sq) 5976
## cor(conditioncongruent,period_9_sq) 6349
## cor(conditionincongruent,period_9_sq) 6709
## cor(period_9,period_9_sq) 6539
## cor(Intercept,conditioncongruent:period_9) 6497
## cor(conditioncongruent,conditioncongruent:period_9) 6442
## cor(conditionincongruent,conditioncongruent:period_9) 6945
## cor(period_9,conditioncongruent:period_9) 6295
## cor(period_9_sq,conditioncongruent:period_9) 6179
## cor(Intercept,conditionincongruent:period_9) 5501
## cor(conditioncongruent,conditionincongruent:period_9) 6716
## cor(conditionincongruent,conditionincongruent:period_9) 6434
## cor(period_9,conditionincongruent:period_9) 6861
## cor(period_9_sq,conditionincongruent:period_9) 6909
## cor(conditioncongruent:period_9,conditionincongruent:period_9) 6309
## cor(Intercept,conditioncongruent:period_9_sq) 6462
## cor(conditioncongruent,conditioncongruent:period_9_sq) 5969
## cor(conditionincongruent,conditioncongruent:period_9_sq) 6406
## cor(period_9,conditioncongruent:period_9_sq) 6458
## cor(period_9_sq,conditioncongruent:period_9_sq) 6426
## cor(conditioncongruent:period_9,conditioncongruent:period_9_sq) 6596
## cor(conditionincongruent:period_9,conditioncongruent:period_9_sq) 6644
## cor(Intercept,conditionincongruent:period_9_sq) 6296
## cor(conditioncongruent,conditionincongruent:period_9_sq) 5731
## cor(conditionincongruent,conditionincongruent:period_9_sq) 6102
## cor(period_9,conditionincongruent:period_9_sq) 6769
## cor(period_9_sq,conditionincongruent:period_9_sq) 6629
## cor(conditioncongruent:period_9,conditionincongruent:period_9_sq) 7087
## cor(conditionincongruent:period_9,conditionincongruent:period_9_sq) 7271
## cor(conditioncongruent:period_9_sq,conditionincongruent:period_9_sq) 6381
##
## Regression Coefficients:
## Estimate Est.Error l-95% CI u-95% CI Rhat
## Intercept -1.17 0.30 -1.79 -0.63 1.00
## conditioncongruent 0.22 0.25 -0.31 0.69 1.00
## conditionincongruent -0.60 0.26 -1.03 -0.02 1.00
## period_9 0.19 0.10 -0.01 0.40 1.00
## period_9_sq -0.09 0.05 -0.17 0.01 1.00
## conditioncongruent:period_9 -0.46 0.10 -0.66 -0.27 1.00
## conditionincongruent:period_9 -0.50 0.23 -0.98 -0.04 1.00
## conditioncongruent:period_9_sq 0.08 0.04 0.00 0.16 1.00
## conditionincongruent:period_9_sq 0.14 0.02 0.09 0.18 1.00
## Bulk_ESS Tail_ESS
## Intercept 3226 4607
## conditioncongruent 4058 5278
## conditionincongruent 4243 5549
## period_9 9852 4863
## period_9_sq 4768 5664
## conditioncongruent:period_9 6075 4162
## conditionincongruent:period_9 6022 5379
## conditioncongruent:period_9_sq 5569 4748
## conditionincongruent:period_9_sq 6812 4938
##
## Draws were sampled using sample(hmc). For each parameter, Bulk_ESS
## and Tail_ESS are effective sample size measures, and Rhat is the potential
## scale reduction factor on split chains (at convergence, Rhat = 1).
# Display the model formula stored in the fitted M1d object.
formula(model_M1d)
## event ~ 0 + Intercept + condition * period_9 + condition * period_9_sq + (1 + condition * period_9 + condition * period_9_sq | pid)
One could also fit a model with the variables prime and period_9. However, to include interactions between an index-coded categorical variable and a continuous variable in brm(), one has to switch to the non-linear syntax, as illustrated in the following model formula.
# Non-linear brms formula: a = baseline cloglog-hazard, b = linear TIME
# trend, c = quadratic TIME trend; each non-linear parameter varies by
# prime and by participant, and |i| correlates the participant-level terms
# across a, b, and c. (Uses `<-` for assignment per R convention; `=` at
# top level behaves the same but is discouraged.)
formula <- bf(event ~ 0 + a + b * period_9 + c * period_9_sq,
              a ~ 0 + prime + (0 + prime |i| pid),
              b ~ 0 + prime + (0 + prime |i| pid),
              c ~ 0 + prime + (0 + prime |i| pid),
              nl = TRUE)
The priors could be set as follows.
# check which priors to set for the non-linear formulation
# (get_prior() lists one row per class/coef/nlpar combination; note the
# default flat priors on the b-class coefficients of a, b, and c below)
get_prior(formula=bf(event ~ 0 + a + b * period_9 + c * period_9_sq,
a ~ 0 + prime + (0 + prime |i| pid),
b ~ 0 + prime + (0 + prime |i| pid),
c ~ 0 + prime + (0 + prime |i| pid),
nl = TRUE),
data = ptb_data %>%
select(pid, event, period_9, prime) %>%
mutate(period_9_sq = period_9^2),
family = bernoulli(link="cloglog"))
## prior class coef group resp dpar nlpar lb ub source
## lkj(1) cor default
## lkj(1) cor pid (vectorized)
## (flat) b a default
## (flat) b prime1 a (vectorized)
## (flat) b prime2 a (vectorized)
## (flat) b prime3 a (vectorized)
## student_t(3, 0, 2.5) sd a 0 default
## student_t(3, 0, 2.5) sd pid a 0 (vectorized)
## student_t(3, 0, 2.5) sd prime1 pid a 0 (vectorized)
## student_t(3, 0, 2.5) sd prime2 pid a 0 (vectorized)
## student_t(3, 0, 2.5) sd prime3 pid a 0 (vectorized)
## (flat) b b default
## (flat) b prime1 b (vectorized)
## (flat) b prime2 b (vectorized)
## (flat) b prime3 b (vectorized)
## student_t(3, 0, 2.5) sd b 0 default
## student_t(3, 0, 2.5) sd pid b 0 (vectorized)
## student_t(3, 0, 2.5) sd prime1 pid b 0 (vectorized)
## student_t(3, 0, 2.5) sd prime2 pid b 0 (vectorized)
## student_t(3, 0, 2.5) sd prime3 pid b 0 (vectorized)
## (flat) b c default
## (flat) b prime1 c (vectorized)
## (flat) b prime2 c (vectorized)
## (flat) b prime3 c (vectorized)
## student_t(3, 0, 2.5) sd c 0 default
## student_t(3, 0, 2.5) sd pid c 0 (vectorized)
## student_t(3, 0, 2.5) sd prime1 pid c 0 (vectorized)
## student_t(3, 0, 2.5) sd prime2 pid c 0 (vectorized)
## student_t(3, 0, 2.5) sd prime3 pid c 0 (vectorized)
# set priors — one prior per non-linear parameter:
#   a = baseline cloglog-hazard levels, b = linear TIME trends,
#   c = quadratic TIME trends; exponential(1) on the participant-level
#   standard deviations and LKJ(2) on their correlation matrix.
priors <- c(
prior(skew_normal(-1,1,-2), class = b, nlpar = a),
prior(normal(0.1,.2), class = b, nlpar = b),
prior(normal(0,.7), class = b, nlpar = c),
prior(exponential(1), class = sd, group = pid, nlpar = a),
prior(exponential(1), class = sd, group = pid, nlpar = b),
prior(exponential(1), class = sd, group = pid, nlpar = c),
prior(lkj(2), class = cor, group = pid)
)
Up until now, we have been working with one time scale, TIME or bin rank within a trial. While many cognitive processes play out on this short time scale (milliseconds to seconds), some play out on longer time scales (minutes to hours), e.g., learning processes.
When we are interested in studying how the hazard of response occurrence in our priming experiment also changes on a longer time scale, we can add trial number into the model formula. Here we simply illustrate some possibilities with index and reference coding.
First, we can categorize the predictor trial number by grouping trials in one of three “stages” of the experiment (stage 1 = trials 1 to 500; stage 2 = trials 501 to 1000; stage 3 = trials 1001 and later) and use index coding (variable “stage” in ptb_data).
# Index coding: one hazard parameter per timebin-by-prime-by-stage cell,
# with matching participant-level terms.
event ~ 0 + timebin:prime:stage + (0 + timebin:prime:stage | pid)
## event ~ 0 + timebin:prime:stage + (0 + timebin:prime:stage |
## pid)
Second, we can treat trial number as a continuous variable, and make assumptions about the way hazard changes with trial number (e.g., linear, quadratic, etc.) for each combination of timebin and prime.
# Non-linear syntax: intercept (a), linear (b) and quadratic (c) trial-number
# trends, each with a separate parameter per timebin-by-prime cell and
# correlated (|i|) participant-level terms.
bf(event ~ 0 + a + b * trial_c + c * trial_c_sq,
a ~ 0 + timebin:prime + (0 + timebin:prime |i| pid),
b ~ 0 + timebin:prime + (0 + timebin:prime |i| pid),
c ~ 0 + timebin:prime + (0 + timebin:prime |i| pid),
nl = TRUE)
## event ~ 0 + a + b * trial_c + c * trial_c_sq
## a ~ 0 + timebin:prime + (0 + timebin:prime | i | pid)
## b ~ 0 + timebin:prime + (0 + timebin:prime | i | pid)
## c ~ 0 + timebin:prime + (0 + timebin:prime | i | pid)
Third, one can use dummy coding for prime type (variable “condition” in ptb_data), and treat TIME (“period_9”) and trial number (“trial_c”) as continuous variables.
# Dummy-coded condition with continuous TIME (period_9) and trial number
# (trial_c), plus their quadratic terms, and matching participant-level terms.
# FIX: in the original, the random-effects parenthesis closed prematurely
# after `trial_c^2)` (which also used `trial_c^2` instead of the precomputed
# `trial_c_sq`), leaving the `| pid` bar term malformed: the last two
# population-level products ended up outside the grouping parentheses.
bf(event ~ 0 + Intercept +
     condition*period_9*trial_c +
     condition*period_9*trial_c_sq +
     condition*period_9_sq*trial_c +
     condition*period_9_sq*trial_c_sq +
     (1 + condition*period_9*trial_c +
        condition*period_9*trial_c_sq +
        condition*period_9_sq*trial_c +
        condition*period_9_sq*trial_c_sq | pid))
## event ~ 0 + Intercept + condition * period_9 * trial_c + condition * period_9 * trial_c_sq + condition * period_9_sq * trial_c + condition * period_9_sq * trial_c_sq + (1 + condition * period_9 * trial_c + condition * period_9 * trial_c^2) + condition * period_9_sq * trial_c + condition * period_9_sq * trial_c_sq | pid
The predictive accuracy of a set of models can be compared using WAIC and LOO.
# Reload all five fitted models that were cached to disk by earlier brm() calls.
model_dir <- "Tutorial_2_Bayesian/models"
model_M0i <- readRDS(file.path(model_dir, "model_M0i.rds"))
model_M0d <- readRDS(file.path(model_dir, "model_M0d.rds"))
model_M0c <- readRDS(file.path(model_dir, "model_M0c.rds"))
model_M1i <- readRDS(file.path(model_dir, "model_M1i.rds"))
model_M1d <- readRDS(file.path(model_dir, "model_M1d.rds"))
We use WAIC and LOO to compare these non-nested models.
# Compute LOO and WAIC for each model; add_criterion() caches the criteria
# inside the returned model object so they are not recomputed later.
model_M0i <- add_criterion(model_M0i, c("loo", "waic"))
model_M0d <- add_criterion(model_M0d, c("loo", "waic"))
model_M0c <- add_criterion(model_M0c, c("loo", "waic"))
model_M1i <- add_criterion(model_M1i, c("loo", "waic"))
model_M1d <- add_criterion(model_M1d, c("loo", "waic"))
Compare all five models.
# Rank the five models by expected log predictive density (LOO-CV).
# FIX: use TRUE/FALSE, never the reassignable aliases T/F.
loo_compare(model_M0i, model_M0d, model_M0c, model_M1i, model_M1d, criterion = "loo") %>% print(simplify = FALSE)
## elpd_diff se_diff elpd_loo se_elpd_loo p_loo se_p_loo looic
## model_M1i 0.0 0.0 -5111.6 61.7 128.1 3.0 10223.3
## model_M1d -37.4 15.0 -5149.0 62.4 49.8 1.5 10298.1
## model_M0i -420.5 27.3 -5532.1 63.7 49.5 1.0 11064.2
## model_M0d -423.8 27.5 -5535.4 63.8 54.4 1.2 11070.8
## model_M0c -430.7 28.9 -5542.3 63.6 18.3 0.5 11084.6
## se_looic
## model_M1i 123.3
## model_M1d 124.9
## model_M0i 127.5
## model_M0d 127.6
## model_M0c 127.1
# Same comparison using WAIC.
# FIX: use TRUE/FALSE, never the reassignable aliases T/F.
loo_compare(model_M0i, model_M0d, model_M0c, model_M1i, model_M1d, criterion = "waic") %>% print(simplify = FALSE)
## elpd_diff se_diff elpd_waic se_elpd_waic p_waic se_p_waic waic
## model_M1i 0.0 0.0 -5110.5 61.6 126.9 2.9 10221.0
## model_M1d -38.4 14.9 -5148.9 62.4 49.7 1.5 10297.8
## model_M0i -421.5 27.3 -5532.0 63.7 49.4 1.0 11064.0
## model_M0d -424.8 27.4 -5535.3 63.8 54.2 1.2 11070.6
## model_M0c -431.8 28.9 -5542.3 63.6 18.3 0.5 11084.6
## se_waic
## model_M1i 123.3
## model_M1d 124.9
## model_M0i 127.5
## model_M0d 127.6
## model_M0c 127.1
# Relative model weights based on LOO, rounded to three decimals.
model_weights(model_M0i, model_M0d, model_M0c, model_M1i, model_M1d, weights = "loo") %>% round(digits = 3)
## model_M0i model_M0d model_M0c model_M1i model_M1d
## 0 0 0 1 0
# Relative model weights based on WAIC, rounded to three decimals.
model_weights(model_M0i, model_M0d, model_M0c, model_M1i, model_M1d, weights = "waic") %>% round(digits = 3)
## model_M0i model_M0d model_M0c model_M1i model_M1d
## 0 0 0 1 0
The Pareto k estimates can be displayed as follows.
# Plot the PSIS-LOO Pareto k diagnostic for every observation of model_M1i.
loo(model_M1i)$diagnostics %>%
data.frame() %>%
# attach the `id` values
# NOTE(review): this binds data_M0i onto the model_M1i diagnostics — that
# only lines up if data_M0i has exactly the same rows, in the same order,
# as model_M1i's data; confirm this is not an M0i/M1i mix-up.
bind_cols(data_M0i) %>%
mutate(id = 1:n()) %>%
ggplot(aes(x = id, y = pareto_k)) +
geom_point(alpha = 3/4) +
# label only observations whose pareto_k exceeds 0.2
geom_text(data = . %>% filter(pareto_k > .2),
aes(x = id + 2, label = id),
size = 3, hjust = 0) +
theme(panel.grid = element_blank())
Wrangle the posterior draws.
# Extract the posterior draws of model_M1i as a tibble, dropping the
# log-posterior column.
post <- model_M1i %>%
  as_draws_df() %>%
  select(!lp__) %>%
  as_tibble()
# Robust posterior summaries (median and MAD instead of mean and SD).
post_summary <- posterior_summary(model_M1i, robust = TRUE)
# first 30 rows = the population-level ("b_") coefficients
post_summary[1:30,]
## Estimate Est.Error Q2.5 Q97.5
## b_timebin6:prime1 -4.6142800 0.5956642 -5.5186672 -2.6172545
## b_timebin7:prime1 -3.0370900 0.3949646 -3.7424620 -1.9015655
## b_timebin8:prime1 -1.9069150 0.2510412 -2.4772665 -1.3455780
## b_timebin9:prime1 -1.1416050 0.2271491 -1.6972177 -0.6858400
## b_timebin10:prime1 -0.5818410 0.2643394 -1.2785840 -0.1378814
## b_timebin11:prime1 -0.4360185 0.1837638 -1.0113615 -0.1142685
## b_timebin12:prime1 -0.4257980 0.1702447 -0.9787083 -0.1061844
## b_timebin13:prime1 -0.5860980 0.2989470 -1.3489175 -0.1414447
## b_timebin14:prime1 -0.8060520 0.3050790 -1.6159770 -0.2757566
## b_timebin15:prime1 -1.2259300 0.4402358 -2.3015065 -0.4789412
## b_timebin6:prime2 -2.2776150 0.2899002 -2.8664268 -1.5812525
## b_timebin7:prime2 -1.5313350 0.2154885 -2.0381137 -1.0528518
## b_timebin8:prime2 -1.2685700 0.2303812 -1.8227535 -0.7926707
## b_timebin9:prime2 -0.8250290 0.2849350 -1.5475030 -0.3109109
## b_timebin10:prime2 -0.8661925 0.2812425 -1.5431482 -0.3650001
## b_timebin11:prime2 -0.7774375 0.2725523 -1.4762308 -0.3187765
## b_timebin12:prime2 -0.5605580 0.2181727 -1.1480717 -0.1736950
## b_timebin13:prime2 -0.8646980 0.3544140 -1.7097407 -0.2575006
## b_timebin14:prime2 -1.2479250 0.4003391 -2.1808997 -0.5112653
## b_timebin15:prime2 -1.2724250 0.4825048 -2.3831527 -0.4397186
## b_timebin6:prime3 -2.2622500 0.3914657 -3.0427777 -1.3898838
## b_timebin7:prime3 -2.0636400 0.4873454 -3.0413115 -1.0857520
## b_timebin8:prime3 -2.6676000 0.5286359 -3.6544872 -1.5199973
## b_timebin9:prime3 -2.4215700 0.2800557 -3.0239505 -1.7384690
## b_timebin10:prime3 -1.8555650 0.4168923 -2.8024945 -1.0426678
## b_timebin11:prime3 -1.3114950 0.2997150 -2.0448833 -0.7438644
## b_timebin12:prime3 -1.0061350 0.2263708 -1.5681517 -0.5927057
## b_timebin13:prime3 -0.9334835 0.2431501 -1.5733905 -0.4806148
## b_timebin14:prime3 -1.0480200 0.2188021 -1.6111005 -0.6284875
## b_timebin15:prime3 -1.1136550 0.3581458 -1.9308520 -0.4511871
# Posterior medians with 95% quantile intervals for each population-level
# ("b_") coefficient, sorted by parameter name.
post_qi_b <- post %>%
  select(starts_with("b_")) %>%
  pivot_longer(cols = everything(), names_to = "name", values_to = "value") %>%
  group_by(name) %>%
  median_qi(value) %>%
  arrange(name)
head(post_qi_b) # 30 "fixed" effects
## # A tibble: 6 × 7
## name value .lower .upper .width .point .interval
## <chr> <dbl> <dbl> <dbl> <dbl> <chr> <chr>
## 1 b_timebin10:prime1 -0.582 -1.28 -0.138 0.95 median qi
## 2 b_timebin10:prime2 -0.866 -1.54 -0.365 0.95 median qi
## 3 b_timebin10:prime3 -1.86 -2.80 -1.04 0.95 median qi
## 4 b_timebin11:prime1 -0.436 -1.01 -0.114 0.95 median qi
## 5 b_timebin11:prime2 -0.777 -1.48 -0.319 0.95 median qi
## 6 b_timebin11:prime3 -1.31 -2.04 -0.744 0.95 median qi
Visualise fixed effects.
# Reshape the population-level draws to long format and parse the timebin
# and prime indices out of the coefficient names by character position.
# Names look like "b_timebin6:prime1" or "b_timebin10:prime2":
#   chars 10-11 hold the bin number (with a trailing ":" for 1-digit bins,
#   removed below); chars 17-18 hold the prime index (with a leading "e"
#   for 1-digit bins, removed below). Brittle if naming ever changes.
tidy_fixed <- post %>%
select(starts_with("b_"), .chain, .iteration, .draw) %>%
rename(chain=.chain, iter=.iteration, draw=.draw) %>%
pivot_longer(-c(chain, draw, iter)) %>%
mutate(timebin = str_sub(name,10,11),
timebin = factor(str_remove(timebin,":"),levels=c(6:15)),
condition = str_sub(name,17,18),
condition = factor(str_remove(condition,"e"),
levels=c(1,2,3),
labels=c("blank","congruent","incongruent")))
head(tidy_fixed)
## # A tibble: 6 × 7
## chain iter draw name value timebin condition
## <int> <int> <int> <chr> <dbl> <fct> <fct>
## 1 1 1 1 b_timebin6:prime1 -4.63 6 blank
## 2 1 1 1 b_timebin7:prime1 -2.75 7 blank
## 3 1 1 1 b_timebin8:prime1 -1.62 8 blank
## 4 1 1 1 b_timebin9:prime1 -0.900 9 blank
## 5 1 1 1 b_timebin10:prime1 -0.151 10 blank
## 6 1 1 1 b_timebin11:prime1 -0.449 11 blank
tail(tidy_fixed)
## # A tibble: 6 × 7
## chain iter draw name value timebin condition
## <int> <int> <int> <chr> <dbl> <fct> <fct>
## 1 4 2000 8000 b_timebin10:prime3 -2.22 10 incongruent
## 2 4 2000 8000 b_timebin11:prime3 -0.981 11 incongruent
## 3 4 2000 8000 b_timebin12:prime3 -1.10 12 incongruent
## 4 4 2000 8000 b_timebin13:prime3 -1.05 13 incongruent
## 5 4 2000 8000 b_timebin14:prime3 -1.07 14 incongruent
## 6 4 2000 8000 b_timebin15:prime3 -1.19 15 incongruent
# Plot the posterior distributions of the 30 population-level effects as
# half-eye intervals (80% and 95%), faceted by prime condition; x-axis
# labels show the millisecond bounds of each 40-ms time bin.
p_tidy_fixed <- ggplot(tidy_fixed, aes(x = timebin, y = value, fill=condition)) +
stat_halfeye(point_interval="median_qi",
.width = c(0.80,0.95),
alpha=0.7) +
labs(title = 'Posterior distributions for population-level\neffects in Model M1i',
x = "time bin", y = "cloglog-hazard") +
scale_fill_brewer(palette = "Dark2") +
scale_x_discrete(labels = str_c("(",(c(6:15)-1)*40,",",c(6:15)*40,"]"), breaks = 6:15) +
theme(axis.text.x = element_text(angle=90)) +
facet_wrap(~condition)
p_tidy_fixed
# save the last displayed plot to disk
ggsave("Tutorial_2_Bayesian/figures/M1i_postdistr.png", width = 10, height = 8, dpi = 800)
We can plot the expected value of the posterior predictive distribution – the predicted hazard values – for the average participant, for each participant in the data set, and for a brand new hypothetical participant.
First, for the average participant, using add_epred_draws().
# Recover the data the model was fit on from the fitted object.
dat_M1i <- as_tibble(model_M1i$data)
# Expected predictions over the full timebin-by-prime grid;
# re_formula = NA drops all group-level effects, i.e. predictions for
# the "average participant".
epreds_grand <- dat_M1i %>%
data_grid(timebin, prime) %>%
add_epred_draws(model_M1i,
re_formula = NA) %>%
mutate(prime = factor(prime,
levels=c(1,2,3),
labels=c("blank","congruent","incongruent"))) %>%
ungroup()
Summarize and plot predicted hazard values.
# Posterior median and 95% interval of predicted hazard per grid cell.
epreds_grand %>%
group_by(timebin,prime) %>%
median_qi(.width=0.95)
## # A tibble: 30 × 8
## timebin prime .epred .lower .upper .width .point .interval
## <fct> <fct> <dbl> <dbl> <dbl> <dbl> <chr> <chr>
## 1 6 blank 0.00986 0.00400 0.0704 0.95 median qi
## 2 6 congruent 0.0974 0.0553 0.186 0.95 median qi
## 3 6 incongruent 0.0989 0.0466 0.221 0.95 median qi
## 4 7 blank 0.0468 0.0234 0.139 0.95 median qi
## 5 7 congruent 0.194 0.122 0.295 0.95 median qi
## 6 7 incongruent 0.119 0.0466 0.287 0.95 median qi
## 7 8 blank 0.138 0.0805 0.229 0.95 median qi
## 8 8 congruent 0.245 0.149 0.364 0.95 median qi
## 9 8 incongruent 0.0671 0.0255 0.196 0.95 median qi
## 10 9 blank 0.273 0.167 0.396 0.95 median qi
## # ℹ 20 more rows
# Line-plus-ribbon plot of the predicted hazard (median and 95% interval)
# for the average participant, one colour per prime condition.
ggplot(epreds_grand, aes(x=timebin, y=.epred,
fill=prime, color=prime)) +
stat_lineribbon(point_interval="median_qi",.width=.95,alpha = 0.5) +
scale_fill_brewer(palette = "Dark2") +
scale_color_brewer(palette = "Dark2") +
# label bins 6..15 with their 40-ms intervals in ms
scale_x_discrete(labels = str_c("(",(c(6:15)-1)*40,",",c(6:15)*40,"]"), breaks = 6:15) +
theme(axis.text.x = element_text(angle=90)) +
# clipping the y-axis to [0, .6] triggers the "Removed ... rows" warning
scale_y_continuous(limits=c(0,.6)) +
labs(y = "predicted hazard") +
ggtitle("Average participant")
## Warning: Removed 743 rows containing missing values or values outside the scale range
## (`stat_slabinterval()`).
ggsave("Tutorial_2_Bayesian/figures/M1i_pred_grand.png", width = 10, height = 8, dpi = 800)
Again, for the average participant, but now using posterior draws.
# Same average-participant plot, but computed directly from the posterior
# draws of the fixed effects: back-transform the cloglog-scale draws to
# hazards with the inverse cloglog, 1 - exp(-exp(x)).
tidy_fixed %>%
mutate(haz = 1-exp((-1)*exp(value))) %>% # inverse cloglog
ggplot(aes(x=timebin, y=haz,
fill=condition, color=condition)) +
stat_lineribbon(point_interval="median_qi",
.width=.95,
alpha = 0.5) +
scale_fill_brewer(palette = "Dark2") +
scale_color_brewer(palette = "Dark2") +
scale_x_discrete(labels = str_c("(",(c(6:15)-1)*40,",",c(6:15)*40,"]"), breaks = 6:15) +
theme(axis.text.x = element_text(angle=90)) +
scale_y_continuous(limits=c(0,.6)) +
labs(y = "predicted hazard") +
ggtitle("Average participant")
## Warning: Removed 743 rows containing missing values or values outside the scale range
## (`stat_slabinterval()`).
Second, for each participant in the data set.
# Expected posterior predictions per observed participant: re_formula = NULL
# includes the group-level (pid) effects in the predictions.
epreds_pid <- dat_M1i %>%
data_grid(pid,timebin, prime) %>%
add_epred_draws(model_M1i,
re_formula = NULL) %>%
mutate(prime = factor(prime,
levels=c(1,2,3),
labels=c("blank","congruent","incongruent"))) %>%
ungroup()
Summarize and plot predicted hazard values.
# Median and 95% interval of the predicted hazard per participant x bin x
# prime (180 rows: 6 pids x 10 bins x 3 primes).
epreds_pid %>%
group_by(pid,timebin,prime) %>%
median_qi(.width=0.95)
## # A tibble: 180 × 9
## pid timebin prime .epred .lower .upper .width .point .interval
## <dbl> <fct> <fct> <dbl> <dbl> <dbl> <dbl> <chr> <chr>
## 1 1 6 blank 0.00948 0.00354 0.0271 0.95 median qi
## 2 1 6 congruent 0.0744 0.0400 0.123 0.95 median qi
## 3 1 6 incongruent 0.0742 0.0393 0.125 0.95 median qi
## 4 1 7 blank 0.0665 0.0390 0.107 0.95 median qi
## 5 1 7 congruent 0.176 0.117 0.246 0.95 median qi
## 6 1 7 incongruent 0.101 0.0557 0.166 0.95 median qi
## 7 1 8 blank 0.222 0.170 0.284 0.95 median qi
## 8 1 8 congruent 0.272 0.194 0.363 0.95 median qi
## 9 1 8 incongruent 0.0375 0.0137 0.0812 0.95 median qi
## 10 1 9 blank 0.404 0.331 0.481 0.95 median qi
## # ℹ 170 more rows
# Predicted hazard (median and 95% interval) for each observed participant,
# one facet per pid.
ggplot(epreds_pid, aes(x=timebin, y=.epred,
fill=prime, color=prime)) +
stat_lineribbon(point_interval="median_qi",
.width=.95,
alpha = 0.5) +
scale_fill_brewer(palette = "Dark2") +
scale_color_brewer(palette = "Dark2") +
scale_x_discrete(labels = str_c("(",(c(6:15)-1)*40,",",c(6:15)*40,"]"), breaks = 6:15) +
theme(axis.text.x = element_text(angle=90)) +
scale_y_continuous(limits=c(0,1)) +
labs(y = "predicted hazard") +
ggtitle("Each participant (N = 6)") +
facet_wrap(~pid, ncol=3)
ggsave("Tutorial_2_Bayesian/figures/M1i_pred_pid.png", width = 10, height = 8, dpi = 800)
Third, for a completely new hypothetical participant.
# Expected posterior predictions for an unobserved participant: pid = 7 is a
# new level, so allow_new_levels must be TRUE, and the person-level effects
# are drawn from the fitted random-effects distribution ("gaussian").
epreds_pid_new <- dat_M1i %>%
data_grid(timebin, prime) %>%
mutate(pid = 7) %>%
add_epred_draws(model_M1i,
re_formula = NULL,
allow_new_levels = TRUE, # use the literal TRUE, not the reassignable T
sample_new_levels = "gaussian") %>%
mutate(prime = factor(prime,
levels=c(1,2,3),
labels=c("blank","congruent","incongruent"))) %>%
ungroup()
Summarize and plot predicted hazard values.
# Median and 95% interval of the predicted hazard for the hypothetical
# new participant (30 rows: 10 bins x 3 primes).
epreds_pid_new %>%
group_by(pid,timebin,prime) %>%
median_qi(.width=.95)
## # A tibble: 30 × 9
## pid timebin prime .epred .lower .upper .width .point .interval
## <dbl> <fct> <fct> <dbl> <dbl> <dbl> <dbl> <chr> <chr>
## 1 7 6 blank 0.00882 0.00111 0.390 0.95 median qi
## 2 7 6 congruent 0.0974 0.0207 0.435 0.95 median qi
## 3 7 6 incongruent 0.0974 0.0119 0.669 0.95 median qi
## 4 7 7 blank 0.0462 0.00550 0.512 0.95 median qi
## 5 7 7 congruent 0.193 0.0584 0.554 0.95 median qi
## 6 7 7 incongruent 0.120 0.00516 0.965 0.95 median qi
## 7 7 8 blank 0.137 0.0304 0.498 0.95 median qi
## 8 7 8 congruent 0.246 0.0592 0.663 0.95 median qi
## 9 7 8 incongruent 0.0664 0.00334 0.838 0.95 median qi
## 10 7 9 blank 0.276 0.0710 0.701 0.95 median qi
## # ℹ 20 more rows
# Predicted hazard for the hypothetical new participant; intervals are wide
# because the draws include uncertainty about the new person-level effects.
ggplot(epreds_pid_new, aes(x=timebin, y=.epred,
fill=prime, color=prime)) +
stat_lineribbon(point_interval="median_qi",
.width=.95,
alpha = 0.5) +
scale_fill_brewer(palette = "Dark2") +
scale_color_brewer(palette = "Dark2") +
scale_x_discrete(labels = str_c("(",(c(6:15)-1)*40,",",c(6:15)*40,"]"), breaks = 6:15) +
theme(axis.text.x = element_text(angle=90)) +
scale_y_continuous(limits=c(0,1)) +
labs(y = "predicted hazard") +
ggtitle("A brand new hypothetical participant")
ggsave("Tutorial_2_Bayesian/figures/M1i_pred_pid_new.png", width = 10, height = 8, dpi = 800)
We are actually interested in the difference in predicted hazard between congruent and blank primes on the one hand, and between incongruent and blank primes on the other hand, for each time bin. When based on the grand mean, this is known as the grand average marginal effect (AME).
# Grand average marginal effects (AMEs): per draw and time bin, difference in
# predicted hazard between the two prime conditions and blank.
epreds_grand_diffs <- epreds_grand %>%
pivot_wider(id_cols = c(timebin, .draw),
names_from = "prime",
values_from = ".epred") %>% # 80000 x 5 (10 bins x 8000 draws)
mutate(`congruent minus blank` = congruent - blank,
`incongruent minus blank` = incongruent - blank) %>%
select(-c(blank,congruent,incongruent)) %>%
# one row per draw x bin x contrast
pivot_longer(cols = c(`congruent minus blank`,`incongruent minus blank`),
names_to = "contrast",
values_to = "diff")
epreds_grand_diffs # 160 000 rows (80000 x 2 contrasts)
## # A tibble: 160,000 × 4
## timebin .draw contrast diff
## <fct> <int> <chr> <dbl>
## 1 6 1 congruent minus blank 0.0834
## 2 6 1 incongruent minus blank 0.0805
## 3 6 2 congruent minus blank 0.149
## 4 6 2 incongruent minus blank 0.0342
## 5 6 3 congruent minus blank 0.0982
## 6 6 3 incongruent minus blank 0.0861
## 7 6 4 congruent minus blank 0.0592
## 8 6 4 incongruent minus blank 0.0476
## 9 6 5 congruent minus blank 0.0403
## 10 6 5 incongruent minus blank 0.0709
## # ℹ 159,990 more rows
Summarize and plot.
# Tabulate the grand AMEs (posterior mean and 95% interval) and save as CSV.
# NOTE(review): the name `table` shadows base::table in this session.
table <- epreds_grand_diffs %>%
group_by(contrast,timebin) %>%
mean_qi(.width = c(.95)) %>%
arrange(contrast,timebin,.width) %>%
select(-c(.width,.point,.interval))
write_csv(table, file="Tutorial_2_Bayesian/tables/grand_AMEs.csv")
# Plot the grand AMEs with 80% and 95% ribbons, one facet per contrast; the
# red dotted line at 0 marks "no effect".
ggplot(epreds_grand_diffs, aes(x=timebin, y=diff,
fill=contrast)) +
stat_lineribbon(alpha = 0.5,
point_interval = "mean_qi",
.width=c(.8,.95)) +
geom_hline(yintercept = 0,
color = "red",
lty = 3) +
scale_y_continuous(limits=c(-0.5,0.3)) +
scale_x_discrete(labels = str_c("(",(c(6:15)-1)*40,",",c(6:15)*40,"]"), breaks = 6:15) +
scale_fill_brewer(palette = "Dark2") +
scale_color_brewer(palette = "Dark2") +
theme(axis.text.x = element_text(angle=90)) +
labs(y = "difference in predicted hazard") +
ggtitle("Grand average marginal effect") +
facet_wrap(~contrast)
## Warning: Removed 357 rows containing missing values or values outside the scale range
## (`stat_slabinterval()`).
## Warning: Removed 262 rows containing missing values or values outside the scale range
## (`stat_slabinterval()`).
ggsave("Tutorial_2_Bayesian/figures/M1i_grand_AME.png", width = 10, height = 8, dpi = 800)
Note that these grand AMEs can also be calculated based on the posterior draws.
# Recompute the grand AMEs directly from the posterior draws of the fixed
# effects: back-transform each coefficient column to the hazard scale, take
# per-bin differences against the blank prime (prime1), reshape, and plot.
# NOTE(review): mutate_all() is superseded; mutate(across(everything(), ...))
# is the modern equivalent — left as-is to keep the rendered output in sync.
post %>% # 8000 x 30 = 240000
mutate_all(function(x){1-exp((-1)*exp(x))}) %>%
# _cb = congruent minus blank, _ib = incongruent minus blank
mutate(bin6_cb = `b_timebin6:prime2` - `b_timebin6:prime1`,
bin7_cb = `b_timebin7:prime2` - `b_timebin7:prime1`,
bin8_cb = `b_timebin8:prime2` - `b_timebin8:prime1`,
bin9_cb = `b_timebin9:prime2` - `b_timebin9:prime1`,
bin10_cb = `b_timebin10:prime2` - `b_timebin10:prime1`,
bin11_cb = `b_timebin11:prime2` - `b_timebin11:prime1`,
bin12_cb = `b_timebin12:prime2` - `b_timebin12:prime1`,
bin13_cb = `b_timebin13:prime2` - `b_timebin13:prime1`,
bin14_cb = `b_timebin14:prime2` - `b_timebin14:prime1`,
bin15_cb = `b_timebin15:prime2` - `b_timebin15:prime1`,
bin6_ib = `b_timebin6:prime3` - `b_timebin6:prime1`,
bin7_ib = `b_timebin7:prime3` - `b_timebin7:prime1`,
bin8_ib = `b_timebin8:prime3` - `b_timebin8:prime1`,
bin9_ib = `b_timebin9:prime3` - `b_timebin9:prime1`,
bin10_ib = `b_timebin10:prime3` - `b_timebin10:prime1`,
bin11_ib = `b_timebin11:prime3` - `b_timebin11:prime1`,
bin12_ib = `b_timebin12:prime3` - `b_timebin12:prime1`,
bin13_ib = `b_timebin13:prime3` - `b_timebin13:prime1`,
bin14_ib = `b_timebin14:prime3` - `b_timebin14:prime1`,
bin15_ib = `b_timebin15:prime3` - `b_timebin15:prime1`) %>%
select(starts_with("bin")) %>% # 8000 x 20
pivot_longer(cols = bin6_cb:bin15_ib,
names_to = "condition",
values_to = "diff") %>%
# parse the bin number and contrast code back out of names like "bin6_cb"
# and "bin15_ib": characters 4-5 hold the bin (a stray "_" is stripped for
# one-digit bins), characters 6-8 hold "cb"/"ib" (leading "_" stripped).
mutate(bin = str_sub(condition,4,5),
bin = str_remove(bin,"_"),
bin = factor(bin, levels=c(6:15)),
comp = str_sub(condition,6,8),
comp = str_remove(comp,"_"),
contrast = ifelse(comp == "cb",
"congruent minus blank", "incongruent minus blank")) %>%
ggplot(aes(x=bin, y=diff, fill=contrast)) +
stat_lineribbon(alpha = 0.5,
point_interval = "mean_qi",
.width=c(.8,.95)) +
geom_hline(yintercept = 0,
color = "red",
lty = 3) +
scale_y_continuous(limits=c(-0.5,0.3)) +
scale_x_discrete(labels = str_c("(",(c(6:15)-1)*40,",",c(6:15)*40,"]"), breaks = 6:15) +
scale_fill_brewer(palette = "Dark2") +
scale_color_brewer(palette = "Dark2") +
theme(axis.text.x = element_text(angle=90)) +
labs(y = "difference in predicted hazard", x = "timebin") +
ggtitle("Grand average marginal effect") +
facet_wrap(~contrast)
## Warning: Removed 357 rows containing missing values or values outside the scale range
## (`stat_slabinterval()`).
## Warning: Removed 262 rows containing missing values or values outside the scale range
## (`stat_slabinterval()`).
To calculate the subject-specific AMEs for each time bin, we create contrasts between the conditional means.
# Subject-specific AMEs: per participant, draw, and time bin, difference in
# predicted hazard between the two prime conditions and blank.
epreds_pid_diffs <- epreds_pid %>%
pivot_wider(id_cols = c(pid,timebin, .draw),
names_from = "prime",
values_from = ".epred") %>% # 480000 x 6 (6 pids x 10 bins x 8000 draws)
mutate(`congruent minus blank` = congruent - blank,
`incongruent minus blank` = incongruent - blank) %>%
select(-c(blank,congruent,incongruent)) %>%
pivot_longer(cols = c(`congruent minus blank`,`incongruent minus blank`),
names_to = "contrast",
values_to = "diff")
epreds_pid_diffs # 960 000 rows (480000 x 2 contrasts)
## # A tibble: 960,000 × 5
## pid timebin .draw contrast diff
## <dbl> <fct> <int> <chr> <dbl>
## 1 1 6 1 congruent minus blank 0.0567
## 2 1 6 1 incongruent minus blank 0.117
## 3 1 6 2 congruent minus blank 0.0370
## 4 1 6 2 incongruent minus blank 0.0111
## 5 1 6 3 congruent minus blank 0.0603
## 6 1 6 3 incongruent minus blank 0.0643
## 7 1 6 4 congruent minus blank 0.113
## 8 1 6 4 incongruent minus blank 0.0344
## 9 1 6 5 congruent minus blank 0.0286
## 10 1 6 5 incongruent minus blank 0.0748
## # ℹ 959,990 more rows
Summarize the contrasts in predicted hazard values.
# Tabulate the subject-specific AMEs (posterior mean, 95% interval) per
# participant x bin x contrast, and save as CSV.
# NOTE(review): the name `table` shadows base::table in this session.
table <- epreds_pid_diffs %>%
group_by(pid,timebin,contrast) %>%
mean_qi(.width = .95) %>%
arrange(contrast,timebin,.width) %>%
select(-c(.width,.point,.interval))
write_csv(table, file="Tutorial_2_Bayesian/tables/pid_AMEs.csv")
Plot the contrasts in predicted hazard values.
# Plot the subject-specific AMEs with 80% and 95% ribbons, one facet per
# participant; the red dotted line at 0 marks "no effect".
ggplot(epreds_pid_diffs, aes(x=timebin, y=diff,
fill=contrast, color=contrast)) +
stat_lineribbon(alpha = 0.5,
point_interval = "mean_qi",
.width=c(.8,.95)) +
#stat_halfeye(point_interval = "mean_qi",.width=c(.8,.95)) +
geom_hline(yintercept = 0, color = "red", lty = 3) +
scale_fill_brewer(palette = "Dark2") +
scale_color_brewer(palette = "Dark2") +
scale_x_discrete(labels = str_c("(",(c(6:15)-1)*40,",",c(6:15)*40,"]"), breaks = 6:15) +
theme(axis.text.x = element_text(angle=90)) +
scale_y_continuous(limits=c(-0.6,0.6)) +
labs(y = "difference in predicted hazard") +
ggtitle("Subject-specific AMEs") +
facet_wrap(~pid, ncol=3)
## Warning: Removed 1455 rows containing missing values or values outside the scale range
## (`stat_slabinterval()`).
## Warning: Removed 2308 rows containing missing values or values outside the scale range
## (`stat_slabinterval()`).
## Warning: Removed 35 rows containing missing values or values outside the scale range
## (`stat_slabinterval()`).
## Warning: Removed 3578 rows containing missing values or values outside the scale range
## (`stat_slabinterval()`).
## Warning: Removed 163 rows containing missing values or values outside the scale range
## (`stat_slabinterval()`).
## Warning: Removed 498 rows containing missing values or values outside the scale range
## (`stat_slabinterval()`).
ggsave("Tutorial_2_Bayesian/figures/M1i_pid_AMEs.png", width = 10, height = 8, dpi = 800)
(Editorial note: this section and the next two code chunks may be removed in a future revision.) We also check whether an average of the subject-specific AMEs is equal to the grand AME — it is not.
# Average the subject-specific AMEs over participants, per draw x bin x
# contrast, and compare the result to the grand AME.
# NOTE(review): average_condmeans is computed but the plot below maps
# y = diff (all subject-level draws pooled); presumably y = average_condmeans
# was intended — confirm before relying on this figure.
epreds_pid_diffs_av <- epreds_pid_diffs %>% group_by(timebin,contrast,.draw) %>%
mutate(average_condmeans = mean(diff))
ggplot(epreds_pid_diffs_av, aes(x=timebin, y=diff,
fill=contrast)) +
stat_lineribbon(alpha = 0.5, point_interval = "mean_qi",.width=c(.8,.95)) +
# stat_halfeye(point_interval = "mean_qi",.width=c(.8,.95))+
geom_hline(yintercept = 0, color = "red", lty = 3) +
scale_y_continuous(limits=c(-0.5,0.3)) +
scale_x_discrete(labels = str_c("(",(c(6:15)-1)*40,",",c(6:15)*40,"]"), breaks = 6:15) +
scale_fill_brewer(palette = "Dark2") +
scale_color_brewer(palette = "Dark2") +
theme(axis.text.x = element_text(angle=90)) +
labs(y = "difference in predicted hazard") +
ggtitle("Average of subject-specific AMEs") +
facet_wrap(~contrast)
## Warning: Removed 27766 rows containing missing values or values outside the scale range
## (`stat_slabinterval()`).
## Warning: Removed 19624 rows containing missing values or values outside the scale range
## (`stat_slabinterval()`).
# AMEs for the hypothetical new participant (pid = 7): per draw and time bin,
# difference in predicted hazard between each prime condition and blank.
epreds_pid_new_diffs <- epreds_pid_new %>%
pivot_wider(id_cols = c(pid,timebin,.draw),
names_from = "prime",
values_from = ".epred") %>%
mutate(`congruent minus blank` = congruent - blank,
`incongruent minus blank` = incongruent - blank) %>%
select(-c(blank,congruent,incongruent)) %>%
pivot_longer(cols = c(`congruent minus blank`,`incongruent minus blank`),
names_to = "contrast",
values_to = "diff")
epreds_pid_new_diffs
## # A tibble: 160,000 × 5
## pid timebin .draw contrast diff
## <dbl> <fct> <int> <chr> <dbl>
## 1 7 6 1 congruent minus blank 0.0668
## 2 7 6 1 incongruent minus blank 0.105
## 3 7 6 2 congruent minus blank 0.221
## 4 7 6 2 incongruent minus blank -0.00898
## 5 7 6 3 congruent minus blank 0.100
## 6 7 6 3 incongruent minus blank 0.121
## 7 7 6 4 congruent minus blank 0.0733
## 8 7 6 4 incongruent minus blank 0.199
## 9 7 6 5 congruent minus blank 0.0574
## 10 7 6 5 incongruent minus blank 0.0383
## # ℹ 159,990 more rows
Summarize and plot the contrasts in predicted hazard values.
# Median and 95% interval of the new-participant AMEs per bin x contrast
# (20 rows: 10 bins x 2 contrasts).
epreds_pid_new_diffs %>%
group_by(pid,timebin,contrast) %>%
median_qi(.width = .95) %>%
arrange(contrast,timebin,.width) %>%
select(-c(.width,.point,.interval))
## # A tibble: 20 × 6
## pid timebin contrast diff .lower .upper
## <dbl> <fct> <chr> <dbl> <dbl> <dbl>
## 1 7 6 congruent minus blank 0.0794 -0.275 0.411
## 2 7 7 congruent minus blank 0.135 -0.314 0.500
## 3 7 8 congruent minus blank 0.101 -0.295 0.518
## 4 7 9 congruent minus blank 0.0757 -0.439 0.640
## 5 7 10 congruent minus blank -0.0781 -0.629 0.515
## 6 7 11 congruent minus blank -0.102 -0.523 0.390
## 7 7 12 congruent minus blank -0.0441 -0.408 0.336
## 8 7 13 congruent minus blank -0.0731 -0.539 0.484
## 9 7 14 congruent minus blank -0.105 -0.551 0.451
## 10 7 15 congruent minus blank -0.00386 -0.716 0.636
## 11 7 6 incongruent minus blank 0.0789 -0.244 0.635
## 12 7 7 incongruent minus blank 0.0590 -0.342 0.866
## 13 7 8 incongruent minus blank -0.0567 -0.418 0.665
## 14 7 9 incongruent minus blank -0.180 -0.616 0.139
## 15 7 10 incongruent minus blank -0.247 -0.744 0.461
## 16 7 11 incongruent minus blank -0.223 -0.614 0.368
## 17 7 12 incongruent minus blank -0.169 -0.515 0.285
## 18 7 13 incongruent minus blank -0.0987 -0.533 0.422
## 19 7 14 incongruent minus blank -0.0677 -0.499 0.340
## 20 7 15 incongruent minus blank 0.0184 -0.697 0.718
# Plot the new-participant AMEs with a 95% ribbon, one facet per contrast.
ggplot(epreds_pid_new_diffs, aes(x=timebin, y=diff,
fill=contrast, color=contrast)) +
stat_lineribbon(alpha = 0.5,
point_interval = "mean_qi",
.width=c(.95)) +
#stat_halfeye(point_interval = "mean_qi",.width=c(.8,.95)) +
geom_hline(yintercept = 0, color = "red", lty = 3) +
scale_fill_brewer(palette = "Dark2") +
scale_color_brewer(palette = "Dark2") +
scale_x_discrete(labels = str_c("(",(c(6:15)-1)*40,",",c(6:15)*40,"]"), breaks = 6:15) +
theme(axis.text.x = element_text(angle=90)) +
scale_y_continuous(limits=c(-1,1)) +
labs(y = "difference in predicted hazard") +
ggtitle("AMEs for a brand new\nhypothetical participant") +
facet_wrap(~contrast)
ggsave("Tutorial_2_Bayesian/figures/M1i_pid_new_AMEs.png", width = 10, height = 8, dpi = 800)
What can we conclude from model M1i about our research question, i.e., the temporal dynamics of the effect of prime-target congruency on RT? In other words, in which of the 40-ms time bins between 200 and 600 ms after target onset does changing the prime from blank to congruent or incongruent affect the hazard of response occurrence (for a prime-target SOA of 187 ms)?
If we want to study the average effect of prime type on hazard, uncontaminated by inter-individual differences, we can base our conclusion on Figure 8 and Table 4. The contrast “congruent minus blank” was estimated to be 0.09 hazard units in bin 6 (95% CrI = [0.02, 0.17]), and 0.14 hazard units in bin 7 (95% CrI = [0.04, 0.25]). For the other bins, the 95% credible interval contained zero. The contrast “incongruent minus blank” was estimated to be 0.09 hazard units in bin 6 (95% CrI = [0.01, 0.21]), -0.19 hazard units in bin 9 (95% CrI = [-0.31, -0.06]), -0.27 hazard units in bin 10 (95% CrI = [-0.45, -0.04]), and -0.23 hazard units in bin 11 (95% CrI = [-0.40, -0.03]). For the other bins, the 95% credible interval contained zero. Note that we could also have calculated hazard ratios instead of hazard differences.
There are thus two phases of performance for the average person between 200 and 600 ms after target onset. In the first phase, the addition of a congruent or incongruent prime stimulus increases the hazard of response occurrence compared to blank prime trials in the time period (200, 240]. In the second phase, only the incongruent prime decreases the hazard of response occurrence compared to blank primes, in the time period (320,440]. The sign of the effect of incongruent primes on the hazard of response occurrence thus depends on how much waiting time has passed since target onset.
The posterior distribution of each contrast can also be summarized by considering its proportion below or above some value, like zero. For example, here are the proportions that each contrast is larger and smaller than 0:
# Posterior probability that each grand-AME contrast is above zero, per bin.
# NOTE(review): adding .groups = "drop" to summarize() would silence the
# grouping message and drop the leftover timebin grouping — confirm before
# changing, as the rendered output records the message.
pabove <- epreds_grand_diffs %>%
group_by(timebin,contrast) %>%
summarize(prop_above = mean(diff > 0)) %>%
arrange(contrast,timebin)
## `summarise()` has grouped output by 'timebin'. You can override using the
## `.groups` argument.
# Posterior probability that each grand-AME contrast is below zero, per bin.
pbelow <- epreds_grand_diffs %>%
group_by(timebin,contrast) %>%
summarize(prop_below = mean(diff < 0)) %>%
arrange(contrast,timebin)
## `summarise()` has grouped output by 'timebin'. You can override using the
## `.groups` argument.
# Join the above/below proportions per bin x contrast and save as CSV.
table_prop <- pabove %>% inner_join(pbelow, by = c("timebin","contrast"))
write_csv(table_prop, file="Tutorial_2_Bayesian/tables/contrasts_props.csv")
Thus, the probability that the contrast “congruent minus blank” is larger than 0, is larger than .9 in bins 6 to 8. And the probability that the contrast “incongruent minus blank” is smaller than 0, is larger than .9 in bins 9 to 12.
If we want to focus more on inter-individual differences, we can study the subject-specific differences in hazard in Figure 9. Note that three participants (1, 2, and 3) show a negative difference for the contrast “congruent minus incongruent” in bin (360,400] – subject 2 also in bin (320,360].
Future studies could (a) increase the number of participants to estimate the proportion of “dippers” in the subject population, and/or (b) try to explain why this dip occurs. For example, @panisWhatShapingRT2016 concluded that active, top-down, task-guided response inhibition effects emerge around 360 ms after the onset of the stimulus following the prime (here: the target stimulus). Such a top-down inhibitory effect might exist in our priming data set, because after some time participants will learn that the first stimulus is not the one they have to respond to. To prevent a premature overt response to the prime they thus might gradually increase a global response threshold during the remainder of the experiment, which could result in a lower hazard in congruent trials compared to blank trials, for bins after ~360 ms, and towards the end of the experiment. This effect might be masked for incongruent primes by the response competition effect.
Interestingly, all subjects show a tendency in their mean difference (congruent minus blank) to “dip” around that time (Figure 9). Therefore, future modeling efforts could incorporate the trial number into the model formula, in order to also study how the effects of prime type on hazard change on the long experiment-wide time scale, next to the short trial-wide time scale. In Tutorial_2a.Rmd we provide a number of model formulas that should get you going.
# Plot the logit and cloglog link functions over the open interval (0,1).
probability <- (1:99999)/100000
logistic <- function(x) { return( 1/(1+exp(-1*x)) )}
logit <- function(x) { return( log(x/(1-x)) )}
inverse_cloglog <- function(x) { return( 1-(exp(-1*exp(x))) )}
cloglog <- function(x) { return( log(-1*log(1-x)) )}
# Evaluate the links on the probability grid. Use distinct names so the
# function objects cloglog() and logit() are not overwritten by their own
# output (the original re-assigned the same names, clobbering the functions).
cloglog_vals <- cloglog(probability)
logit_vals <- logit(probability)
dataplot <- cbind(probability, cloglog = cloglog_vals, logit = logit_vals) %>%
as_tibble() %>%
pivot_longer(cloglog:logit, names_to = "link", values_to = "value")
ggplot() +
geom_hline(yintercept=0, color="grey") +
# one curve per link function (the original added this identical layer twice)
geom_line(data=dataplot,aes(y=value,x=probability,colour=link),linewidth=1) +
scale_color_brewer(palette = "Dark2") +
# dotted verticals mark the probability that each link maps to 0
geom_vline(xintercept = logistic(0), linetype="dotted", linewidth = 0.3) +
geom_vline(xintercept = inverse_cloglog(0), linetype="dotted", linewidth = 0.3) +
annotate("text", x = logistic(0)-.02, y = -6, label = "logistic(0) = 0.5", angle = 90, size=4) +
annotate("text", x = inverse_cloglog(0)+.02, y = -6, label = "inverse_cloglog(0) = 0.6321", angle=90,size=4) +
labs(x = "Probability",
y = "logit or cloglog scale") +
theme(panel.grid = element_blank())
ggsave("Tutorial_2_Bayesian/figures/linkfunctions.png", width = 8, height = 8, dpi = 800)
To gain a sense of what prior logit values would approximate a uniform distribution on the probability (i.e., discrete-time hazard) scale, Solomon Kurz simulated a large number of draws from the Uniform(0,1) distribution, converted those draws to the log-odds metric, and fitted a Student’s t model. Here we do the same for prior cloglog values: simulate a large number of draws from U(0,1), convert them to the cloglog metric, and fit a skew-normal model (due to the asymmetry of the cloglog link function), to gain a sense of what prior cloglog values would approximate a uniform distribution on the probability (i.e., discrete-time hazard) scale.
# Fit a skew-normal distribution to cloglog-transformed Uniform(0,1) draws,
# to find a cloglog-scale prior that is approximately flat on the hazard scale.
set.seed(11)
# (re)define the link helpers used below
logit <- function(x) { return( log(x/(1-x)) )}
cloglog <- function(x) { return( log(-1*log(1-x)) )}
# generate draws from U(0,1) and convert
dat <-
tibble(p = runif(1e6, 0, 1)) %>%
mutate(g = logit(p),
c = cloglog(p))
# display
dat %>%
ggplot(aes(x = c)) +
geom_histogram(bins = 50) +
scale_y_continuous(NULL, breaks = NULL) +
theme(panel.grid = element_blank())
# fit model
fit_skewN <-
brm(data = dat,
family = skew_normal(),
c ~ 1,
chains = 4, cores = 4,
# file = caches the fit; brm() loads it instead of refitting if it exists
file = "Tutorial_2_Bayesian/models/fit_skewN")
# NOTE(review): this readRDS() re-loads the cached fit that brm() above
# already returns — presumably redundant; confirm before removing.
fit_skewN <- readRDS("Tutorial_2_Bayesian/models/fit_skewN.rds")
summary(fit_skewN)
## Family: skew_normal
## Links: mu = identity; sigma = identity; alpha = identity
## Formula: c ~ 1
## Data: dat (Number of observations: 1000000)
## Draws: 4 chains, each with iter = 2000; warmup = 1000; thin = 1;
## total post-warmup draws = 4000
##
## Regression Coefficients:
## Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
## Intercept -0.59 0.00 -0.59 -0.59 1.00 3127 2835
##
## Further Distributional Parameters:
## Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
## sigma 1.26 0.00 1.26 1.26 1.00 2951 2825
## alpha -4.22 0.01 -4.25 -4.19 1.00 1142 1319
##
## Draws were sampled using sample(hmc). For each parameter, Bulk_ESS
## and Tail_ESS are effective sample size measures, and Rhat is the potential
## scale reduction factor on split chains (at convergence, Rhat = 1).
Now we can reverse the process. We simulate from the skew-Normal distribution based on the posterior means for mu, sigma, and alpha, and then convert the results into the probability (i.e., discrete-time hazard) metric.
# Reverse check: draw from skew-normal with the posterior means estimated
# above (mu = -0.59, sigma = 1.26, alpha = -4.22) and back-transform to the
# probability scale; the histogram should look approximately uniform.
set.seed(11)
inverse_cloglog <- function(x) { return( 1-(exp(-1*exp(x))) )}
tibble(c = rskew_normal(1e6, mu=-0.59, sigma = 1.26, alpha = -4.22) ) %>%
mutate(p = inverse_cloglog(c)) %>%
ggplot(aes(x = p)) +
geom_histogram(bins = 50) +
scale_y_continuous(NULL, breaks = NULL) +
theme(panel.grid = element_blank())
# Link helpers: logit/logistic and cloglog/inverse_cloglog are inverse pairs
# mapping probabilities in (0,1) to the real line and back.
logistic <- function(z) 1 / (1 + exp(-z))
logit <- function(p) log(p / (1 - p))
inverse_cloglog <- function(z) 1 - exp(-exp(z))
cloglog <- function(p) log(-log(1 - p))
# fix the RNG for the reproducible prior draws below
set.seed(23)
# Panel A: A N(0,4) prior on the logit and cloglog scales pushes mass to probabilities of 0 and 1
# pr1 = prior on the link scale; l1/c1 = implied probability-scale
# distributions via the logistic and inverse-cloglog transforms.
pr1 <- tibble(prior = rnorm(1e6, mean = 0, sd = 4)) %>%
ggplot(aes(x = prior)) +
geom_histogram(bins = 50) +
scale_y_continuous(NULL, breaks = NULL) +
annotate(geom="text", x=-13, y=60000, label="N(0,4)",
color="red", size=4) +
annotate(geom = 'text', label = 'A', x = -Inf, y = Inf, hjust = 0, vjust = 1, size=8)+
theme(panel.grid = element_blank())
l1 <- tibble(log_odds = rnorm(1e6, mean = 0, sd = 4)) %>%
mutate(p = logistic(log_odds)) %>%
ggplot(aes(x = p)) +
geom_histogram(bins = 50) +
scale_y_continuous(NULL, breaks = NULL) +
ggtitle("logistic(prior)") +
theme(panel.grid = element_blank())
c1 <- tibble(cloglog_prob = rnorm(1e6, mean = 0, sd = 4)) %>%
mutate(p = inverse_cloglog(cloglog_prob)) %>%
ggplot(aes(x = p)) +
geom_histogram(bins = 50) +
scale_y_continuous(NULL, breaks = NULL) +
ggtitle("inverse_cloglog(prior)")+
theme(panel.grid = element_blank())
# Panel B: A N(0,2) prior on the logit and cloglog scales pushes mass to probabilities of 0 and/or 1
pr2 <- tibble(prior = rnorm(1e6, mean = 0, sd = 2)) %>%
ggplot(aes(x = prior)) +
geom_histogram(bins = 50) +
scale_y_continuous(NULL, breaks = NULL) +
annotate(geom="text", x=-6, y=60000, label="N(0,2)",
color="red", size=4) +
annotate(geom = 'text', label = 'B', x = -Inf, y = Inf, hjust = 0, vjust = 1, size=8)+
theme(panel.grid = element_blank())
# implied distribution on the probability scale via the logistic transform
l2 <- tibble(log_odds = rnorm(1e6, mean = 0, sd = 2)) %>%
mutate(p = logistic(log_odds)) %>%
ggplot(aes(x = p)) +
geom_histogram(bins = 50) +
scale_y_continuous(NULL, breaks = NULL) +
theme(panel.grid = element_blank())
# implied distribution on the hazard scale via the inverse cloglog
c2 <- tibble(cloglog_prob = rnorm(1e6, mean = 0, sd = 2)) %>%
mutate(p = inverse_cloglog(cloglog_prob)) %>%
ggplot(aes(x = p)) +
geom_histogram(bins = 50) +
scale_y_continuous(NULL, breaks = NULL) +
theme(panel.grid = element_blank())
# Panel C: A student-t(df=7.61) prior with scale 1.57 on the logit scale approximates a uniform distribution on the probability scale. This might be a good prior to use for the alpha parameters or Intercept in a logit-hazard model.
pr3 <- tibble(prior = rt(1e6, df = 7.61)* 1.57) %>%
ggplot(aes(x = prior)) +
geom_histogram(bins = 50) +
scale_y_continuous(NULL, breaks = NULL) +
annotate(geom="text", x=-12, y=140000, label="t(7.61,0,1.57)",
color="red", size=4) +
annotate(geom = 'text', label = 'C', x = -Inf, y = Inf, hjust = 0, vjust = 1, size=8)+
theme(panel.grid = element_blank())
# implied distribution on the probability scale via the logistic transform
l3 <- tibble(log_odds = rt(1e6, df = 7.61)* 1.57) %>%
mutate(p = logistic(log_odds)) %>%
ggplot(aes(x = p)) +
geom_histogram(bins = 50) +
scale_y_continuous(NULL, breaks = NULL) +
theme(panel.grid = element_blank())
# implied distribution on the hazard scale via the inverse cloglog
c3 <- tibble(cloglog_prob = rt(1e6, df = 7.61)* 1.57) %>%
mutate(p = inverse_cloglog(cloglog_prob)) %>%
ggplot(aes(x = p)) +
geom_histogram(bins = 50) +
scale_y_continuous(NULL, breaks = NULL) +
theme(panel.grid = element_blank())
# Panel D: A Normal(0,1) prior on the logit scale gently regularizes p towards .5.
pr4 <- tibble(prior = rnorm(1e6, mean = 0, sd = 1))%>%
ggplot(aes(x = prior)) +
geom_histogram(bins = 50) +
scale_y_continuous(NULL, breaks = NULL) +
annotate(geom="text", x=-3, y=60000, label="N(0,1)",
color="red", size=4) +
annotate(geom = 'text', label = 'D', x = -Inf, y = Inf, hjust = 0, vjust = 1, size=8)+
theme(panel.grid = element_blank())
# implied distribution on the probability scale via the logistic transform
l4 <- tibble(log_odds = rnorm(1e6, mean = 0, sd = 1))%>%
mutate(p = logistic(log_odds)) %>%
ggplot(aes(x = p)) +
geom_histogram(bins = 50) +
scale_y_continuous(NULL, breaks = NULL) +
#geom_vline(xintercept=logistic(0), color="red") +
theme(panel.grid = element_blank())
# implied distribution on the hazard scale via the inverse cloglog
c4 <- tibble(cloglog_prob = rnorm(1e6, mean = 0, sd = 1))%>%
mutate(p = inverse_cloglog(cloglog_prob)) %>%
ggplot(aes(x = p)) +
geom_histogram(bins = 50) +
scale_y_continuous(NULL, breaks = NULL) +
theme(panel.grid = element_blank())
# Panel E: A skew_Normal(-0.59,1.26,-4.22) prior on the cloglog scale approximates a uniform distr. on the hazard scale. This uninformative prior might be good for the alpha parameters or Intercept in a cloglog-hazard model.
pr5 <- tibble(prior = rskew_normal(1e6, mu=-0.59, sigma = 1.26, alpha = -4.22)) %>%
ggplot(aes(x = prior)) +
geom_histogram(bins = 50) +
scale_y_continuous(NULL, breaks = NULL) +
annotate(geom="text", x=-5.5, y=55000, label="skew_N(-0.59,1.26,-4.22)",
color="red", size=4) +
annotate(geom = 'text', label = 'E', x = -Inf, y = Inf, hjust = 0, vjust = 1, size=8)+
theme(panel.grid = element_blank())
# no logit-scale panel for this prior
l5 <- ggplot()
# Fix: alpha = -4.22 (the original used -4.20 here, inconsistent with the
# fitted skew-normal estimate used in pr5 and in the reverse check above).
c5 <- tibble(cloglog_prob = rskew_normal(1e6, mu=-0.59, sigma = 1.26, alpha = -4.22)) %>%
mutate(p = inverse_cloglog(cloglog_prob)) %>%
ggplot(aes(x = p)) +
geom_histogram(bins = 50) +
scale_y_continuous(NULL, breaks = NULL) +
theme(panel.grid = element_blank())
# Panel F: The skew_Normal(-1,1,-2) on the cloglog scale is a weakly informative prior for the alpha parameters or Intercept in a cloglog-hazard model, because hazard values below .5 are more likely than values above .5 in general.
pr6 <- tibble(prior = rskew_normal(1e6, mu=-1, sigma = 1, alpha = -2)) %>%
ggplot(aes(x = prior)) +
geom_histogram(bins = 50) +
scale_y_continuous(NULL, breaks = NULL) +
annotate(geom="text", x=-5, y=55000, label="skew_N(-1,1,-2)",
color="red", size=4) +
annotate(geom = 'text', label = 'F', x = -Inf, y = Inf, hjust = 0, vjust = 1, size=8)+
theme(panel.grid = element_blank())
# no logit-scale panel for this prior
l6 <- ggplot()
# implied distribution on the hazard scale via the inverse cloglog
c6 <- tibble(cloglog_prob = rskew_normal(1e6, mu=-1, sigma = 1, alpha = -2)) %>%
mutate(p = inverse_cloglog(cloglog_prob)) %>%
ggplot(aes(x = p)) +
geom_histogram(bins = 50) +
scale_y_continuous(NULL, breaks = NULL) +
theme(panel.grid = element_blank())
# Assemble the six prior panels with patchwork (one row per prior: logit
# panel, link-scale prior, cloglog panel); & applies the theme to all panels.
((l1 + pr1 + c1) / (l2 + pr2 + c2) / (l3 + pr3 + c3) / (l4 + pr4 + c4)/ (l5 + pr5 + c5) / (l6 + pr6 + c6) ) &
theme(text = element_text(size = 10, face = "bold"),
title = element_text(size = 10, face = "bold"))
# save the combined figure to disk
ggsave("Tutorial_2_Bayesian/figures/plot_of_priors.png", width=14, height=13,dpi=800)